File size: 6,094 Bytes
f2ca68f
9587045
 
 
 
f2ca68f
935a747
9587045
 
f2ca68f
9587045
 
f2ca68f
c17d729
9587045
c17d729
 
 
 
9587045
 
 
 
 
 
 
 
26d55ba
9587045
 
 
 
 
 
 
c17d729
 
 
 
 
9587045
c17d729
107b2a4
9587045
 
c17d729
 
 
 
 
 
 
 
 
9587045
 
26d55ba
9587045
c17d729
935a747
9587045
c17d729
 
 
 
 
 
 
 
 
 
935a747
9587045
c17d729
 
 
 
 
 
 
 
 
 
 
9587045
 
 
 
 
 
 
935a747
9587045
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935a747
9587045
 
935a747
9587045
 
c17d729
9587045
 
 
c17d729
ec1fd1e
9587045
 
 
 
 
 
c17d729
 
9587045
c17d729
 
 
 
 
 
9587045
 
 
 
 
 
 
 
 
 
26d55ba
c17d729
9587045
ec1fd1e
9587045
 
 
ec1fd1e
9587045
 
 
c17d729
 
 
 
 
 
 
ec1fd1e
9587045
f2ca68f
9587045
 
 
 
 
 
 
f2ca68f
935a747
9587045
c17d729
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder

# Load dataset: only the first 10,000 rows of the 'train' split are requested
# (see the slice in `split=`). This executes at import time, so importing this
# module triggers the load.
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')

# Text preprocessing function with None handling
def preprocess_text(text, max_length=100):
    """Lower-case and whitespace-tokenise ``text`` into a fixed-length word list.

    Anything that is not a ``str`` (including ``None``) is treated as an empty
    string. The token list is cut to ``max_length`` entries and, when shorter,
    right-padded with empty strings up to ``max_length``.

    Args:
        text: Prompt text, or any non-string value.
        max_length: Number of entries in the returned list.

    Returns:
        A list of lower-cased tokens padded with ``''``.
    """
    source = text if isinstance(text, str) else ""
    tokens = source.lower().split()[:max_length]
    # A non-positive repeat count yields an empty list, so no padding is added
    # once the token list already fills the requested length.
    filler = [''] * (max_length - len(tokens))
    return tokens + filler

class CustomDataset(Dataset):
    """Dataset yielding ``(image tensor, prompt bag-of-words vector, label)`` triples.

    Rows whose ``Model`` value is ``None`` are dropped up front. The remaining
    rows are kept in ``valid_dataset``; ``labels`` holds one encoded class id
    per remaining row (fitted by ``label_encoder``), and ``vocab`` /
    ``word_to_idx`` describe the prompt vocabulary used by ``text_to_vector``.
    """

    def __init__(self, dataset):
        """Filter ``dataset``, fit the label encoder and build the vocabulary.

        Args:
            dataset: Object supporting column access (``dataset['Model']``),
                ``select(indices)`` and row access; rows are read through the
                ``image``, ``prompt`` and ``Model`` fields.
        """
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        # Filter out None values from Model column
        valid_indices = [
            i for i, model in enumerate(dataset['Model']) if model is not None
        ]
        self.valid_dataset = dataset.select(valid_indices)

        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(self.valid_dataset['Model'])

        # Create vocabulary from all prompts
        vocab = set()
        for item in self.valid_dataset['prompt']:
            try:
                vocab.update(preprocess_text(item))
            except Exception as e:
                print(f"Error processing prompt: {e}")

        # The padding token is not a real word.
        vocab.discard('')
        # Sort so that word -> index assignment is reproducible: iteration
        # order of a set of strings can differ from one interpreter run to the
        # next, which would silently change the meaning of every vector slot.
        self.vocab = sorted(vocab)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}

    def __len__(self):
        """Return the number of rows that survived the ``Model`` filter."""
        return len(self.valid_dataset)

    def text_to_vector(self, text):
        """Convert ``text`` into a word-count vector over ``self.vocab``.

        Words not present in the vocabulary (including padding) are ignored.
        On any error a zero vector of the same length is returned and the
        error is printed.
        """
        try:
            vector = torch.zeros(len(self.vocab))
            for word in preprocess_text(text):
                slot = self.word_to_idx.get(word)
                if slot is not None:
                    vector[slot] += 1
            return vector
        except Exception as e:
            print(f"Error converting text to vector: {e}")
            return torch.zeros(len(self.vocab))

    def __getitem__(self, idx):
        """Return ``(image, text_vector, label)`` for row ``idx``.

        Best effort: if anything fails, the error is printed and a zero image
        of shape ``(3, 224, 224)``, a zero text vector and label ``0`` are
        returned instead of raising.
        """
        try:
            # Fetch the row once instead of indexing the dataset twice.
            row = self.valid_dataset[idx]
            image = row['image']
            # Keep every sample at 3 channels so it matches the fallback shape;
            # objects without a `mode` attribute are passed through unchanged.
            if getattr(image, 'mode', 'RGB') != 'RGB':
                image = image.convert('RGB')
            image = self.transform(image)
            text_vector = self.text_to_vector(row['prompt'])
            label = self.labels[idx]
            return image, text_vector, label
        except Exception as e:
            print(f"Error getting item at index {idx}: {e}")
            # Return zero tensors as fallback
            return (torch.zeros((3, 224, 224)),
                    torch.zeros(len(self.vocab)),
                    0)

# Define CNN for image processing
class ImageModel(nn.Module):
    """ResNet-18 backbone whose final layer is replaced by a 512-feature projection."""

    def __init__(self):
        """Build the backbone with ImageNet weights and swap in the new head."""
        super().__init__()
        # `pretrained=True` is deprecated since torchvision 0.13 in favour of
        # the `weights=` argument. IMAGENET1K_V1 is the checkpoint the legacy
        # flag mapped to, so the loaded weights are unchanged. Older
        # torchvision releases without the weights enum keep the legacy call.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.model = models.resnet18(pretrained=True)
        # Replace the classifier; this new layer is randomly initialised.
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        """Return the backbone output for the image batch ``x``."""
        return self.model(x)

# Define MLP for text processing
class TextMLP(nn.Module):
    """Feed-forward network mapping a ``vocab_size``-wide vector to 512 features.

    Layout: Linear(vocab_size, 1024) -> ReLU -> Dropout(0.3) ->
    Linear(1024, 512) -> ReLU -> Dropout(0.2) -> Linear(512, 512).
    """

    def __init__(self, vocab_size):
        """Assemble the layer stack for inputs of width ``vocab_size``."""
        super().__init__()
        wide, narrow = 1024, 512
        stack = [
            nn.Linear(vocab_size, wide),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(wide, narrow),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(narrow, narrow),
        ]
        # Positional construction keeps the sub-module names "0".."6".
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        features = self.layers(x)
        return features

# Combined model
class CombinedModel(nn.Module):
    """Classifier over the concatenated image-branch and text-branch outputs."""

    def __init__(self, vocab_size, num_classes):
        """Create both branches and the final linear classifier.

        Args:
            vocab_size: Width of the text vectors fed to the text branch.
            num_classes: Number of output scores.
        """
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextMLP(vocab_size)
        # The classifier expects the two branch outputs to total 1024 features.
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        """Return class scores for a batch of images and matching text vectors."""
        branches = [self.image_model(image), self.text_model(text)]
        # Join along the feature axis so each sample keeps its own row.
        fused = torch.cat(branches, dim=1)
        return self.fc(fused)

# Create dataset instance. Construction filters the rows, fits the label
# encoder and builds the prompt vocabulary, all at import time.
print("Creating dataset...")
custom_dataset = CustomDataset(dataset)
print(f"Vocabulary size: {len(custom_dataset.vocab)}")
print(f"Number of valid samples: {len(custom_dataset)}")

# Create model: one output score per distinct label seen by the encoder, and a
# text branch as wide as the vocabulary.
# NOTE(review): no training loop or checkpoint load is visible in this file, so
# the classifier layers appear to be used with their initial weights - confirm.
num_classes = len(custom_dataset.label_encoder.classes_)
model = CombinedModel(len(custom_dataset.vocab), num_classes)

def get_recommendations(image):
    """Return example images for the classes the model scores highest for ``image``.

    The model outputs one score per label class. The (up to) five best-scoring
    class ids are each mapped to the first dataset row carrying that label, and
    that row's image is returned with its ``Model`` value as caption.

    Args:
        image: PIL image to score, or ``None`` when nothing was supplied.

    Returns:
        A list of ``(image, caption)`` tuples, best class first. Empty when
        ``image`` is ``None``; classes whose example row cannot be read are
        skipped after printing the error.
    """
    # Nothing to score (e.g. the form was submitted without an upload).
    if image is None:
        return []

    top_k = 5
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    model.eval()
    with torch.no_grad():
        # The image branch needs 3-channel input; uploads may be RGBA,
        # palette-based or greyscale.
        image_tensor = preprocess(image.convert("RGB")).unsqueeze(0)

        # No prompt is available at inference time, so the text branch
        # receives an all-zero vector.
        dummy_text = torch.zeros((1, len(custom_dataset.vocab)))

        output = model(image_tensor, dummy_text)
        # topk raises when k exceeds the number of classes.
        k = min(top_k, output.shape[1])
        _, indices = torch.topk(output, k)

    # These are CLASS ids, not dataset row numbers: look up, in one pass over
    # the labels, the first row that belongs to each predicted class.
    predicted = indices[0].tolist()
    wanted = set(predicted)
    first_row_for_class = {}
    for row_idx, label in enumerate(custom_dataset.labels):
        label = int(label)
        if label in wanted and label not in first_row_for_class:
            first_row_for_class[label] = row_idx
            if len(first_row_for_class) == len(wanted):
                break

    recommendations = []
    for class_idx in predicted:
        row_idx = first_row_for_class.get(class_idx)
        if row_idx is None:
            continue
        try:
            row = custom_dataset.valid_dataset[row_idx]
            recommendations.append((row['image'], f"{row['Model']}"))
        except Exception as e:
            # Best effort: one unreadable row should not hide the others.
            print(f"Error getting recommendation for index {class_idx}: {e}")

    return recommendations

# Set up Gradio interface: one image input, delivered to the callback as a PIL
# image, and a gallery output fed by the list that get_recommendations returns.
interface = gr.Interface(
    fn=get_recommendations,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Recommended Images"),
    title="Image Recommendation System",
    description="Upload an image and get similar images with their model names."
)

# Launch the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    interface.launch()