bgaspra commited on
Commit
ec1fd1e
·
verified ·
1 Parent(s): 02c3e07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -38
app.py CHANGED
@@ -6,45 +6,58 @@ import pandas as pd
6
  from datasets import load_dataset
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  import numpy as np
 
 
9
 
10
  # Load Florence-2 model and processor
11
  model_name = "microsoft/Florence-2-base"
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
14
 
 
15
  model = AutoModelForCausalLM.from_pretrained(
16
  model_name,
17
  torch_dtype=torch_dtype,
18
- trust_remote_code=True
 
19
  ).to(device)
20
  processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
21
 
22
  # Load CivitAI dataset (limited to 1000 samples)
 
23
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
24
  df = pd.DataFrame(dataset)
 
25
 
26
  # Create cache for embeddings to improve performance
27
  text_embedding_cache = {}
28
 
29
  def get_image_embedding(image):
30
- inputs = processor(images=image, return_tensors="pt").to(device, torch_dtype)
31
- with torch.no_grad():
32
- outputs = model.get_image_features(**inputs)
33
- return outputs.cpu().numpy()
 
 
 
 
34
 
35
  def get_text_embedding(text):
36
- if text in text_embedding_cache:
37
- return text_embedding_cache[text]
38
-
39
- inputs = processor(text=text, return_tensors="pt").to(device, torch_dtype)
40
- with torch.no_grad():
41
- outputs = model.get_text_features(**inputs)
42
-
43
- embedding = outputs.cpu().numpy()
44
- text_embedding_cache[text] = embedding
45
- return embedding
 
 
 
 
46
 
47
- # Pre-compute text embeddings for all prompts in the dataset
48
  def precompute_embeddings():
49
  print("Pre-computing text embeddings...")
50
  for idx, row in df.iterrows():
@@ -55,21 +68,21 @@ def precompute_embeddings():
55
  print("Finished pre-computing embeddings")
56
 
57
  def find_similar_images(uploaded_image, top_k=5):
58
- # Get embedding for uploaded image
59
  query_embedding = get_image_embedding(uploaded_image)
 
 
60
 
61
- # Calculate similarities with dataset
62
  similarities = []
63
  for idx, row in df.iterrows():
64
  prompt_embedding = get_text_embedding(row['prompt'])
65
- similarity = cosine_similarity(query_embedding, prompt_embedding)[0][0]
66
- similarities.append({
67
- 'similarity': similarity,
68
- 'model': row['Model'],
69
- 'prompt': row['prompt']
70
- })
 
71
 
72
- # Sort by similarity and get top k results
73
  sorted_results = sorted(similarities, key=lambda x: x['similarity'], reverse=True)
74
  top_models = []
75
  top_prompts = []
@@ -94,21 +107,28 @@ def process_image(input_image):
94
  if input_image is None:
95
  return "Please upload an image.", "Please upload an image."
96
 
97
- # Convert to PIL Image if needed
98
- if not isinstance(input_image, Image.Image):
99
- input_image = Image.fromarray(input_image)
100
-
101
- # Get recommendations
102
- recommended_models, recommended_prompts = find_similar_images(input_image)
103
-
104
- # Format output
105
- models_text = "Recommended Models:\n" + "\n".join([f"{i+1}. {model}" for i, model in enumerate(recommended_models)])
106
- prompts_text = "Recommended Prompts:\n" + "\n".join([f"{i+1}. {prompt}" for i, prompt in enumerate(recommended_prompts)])
107
-
108
- return models_text, prompts_text
 
 
 
 
109
 
110
  # Pre-compute embeddings when starting the application
111
- precompute_embeddings()
 
 
 
112
 
113
  # Create Gradio interface
114
  iface = gr.Interface(
 
6
  from datasets import load_dataset
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  import numpy as np
9
+ import warnings
10
+ warnings.filterwarnings('ignore')
11
 
12
  # Load Florence-2 model and processor
13
  model_name = "microsoft/Florence-2-base"
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
16
 
17
+ # Modify model loading to disable flash attention
18
  model = AutoModelForCausalLM.from_pretrained(
19
  model_name,
20
  torch_dtype=torch_dtype,
21
+ trust_remote_code=True,
22
+ use_flash_attention=False # Disable flash attention
23
  ).to(device)
24
  processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
25
 
26
  # Load CivitAI dataset (limited to 1000 samples)
27
+ print("Loading dataset...")
28
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
29
  df = pd.DataFrame(dataset)
30
+ print("Dataset loaded successfully!")
31
 
32
  # Create cache for embeddings to improve performance
33
  text_embedding_cache = {}
34
 
35
  def get_image_embedding(image):
36
+ try:
37
+ inputs = processor(images=image, return_tensors="pt").to(device, torch_dtype)
38
+ with torch.no_grad():
39
+ outputs = model.get_image_features(**inputs)
40
+ return outputs.cpu().numpy()
41
+ except Exception as e:
42
+ print(f"Error in get_image_embedding: {str(e)}")
43
+ return None
44
 
45
  def get_text_embedding(text):
46
+ try:
47
+ if text in text_embedding_cache:
48
+ return text_embedding_cache[text]
49
+
50
+ inputs = processor(text=text, return_tensors="pt").to(device, torch_dtype)
51
+ with torch.no_grad():
52
+ outputs = model.get_text_features(**inputs)
53
+
54
+ embedding = outputs.cpu().numpy()
55
+ text_embedding_cache[text] = embedding
56
+ return embedding
57
+ except Exception as e:
58
+ print(f"Error in get_text_embedding: {str(e)}")
59
+ return None
60
 
 
61
  def precompute_embeddings():
62
  print("Pre-computing text embeddings...")
63
  for idx, row in df.iterrows():
 
68
  print("Finished pre-computing embeddings")
69
 
70
  def find_similar_images(uploaded_image, top_k=5):
 
71
  query_embedding = get_image_embedding(uploaded_image)
72
+ if query_embedding is None:
73
+ return [], []
74
 
 
75
  similarities = []
76
  for idx, row in df.iterrows():
77
  prompt_embedding = get_text_embedding(row['prompt'])
78
+ if prompt_embedding is not None:
79
+ similarity = cosine_similarity(query_embedding, prompt_embedding)[0][0]
80
+ similarities.append({
81
+ 'similarity': similarity,
82
+ 'model': row['Model'],
83
+ 'prompt': row['prompt']
84
+ })
85
 
 
86
  sorted_results = sorted(similarities, key=lambda x: x['similarity'], reverse=True)
87
  top_models = []
88
  top_prompts = []
 
107
  if input_image is None:
108
  return "Please upload an image.", "Please upload an image."
109
 
110
+ try:
111
+ if not isinstance(input_image, Image.Image):
112
+ input_image = Image.fromarray(input_image)
113
+
114
+ recommended_models, recommended_prompts = find_similar_images(input_image)
115
+
116
+ if not recommended_models or not recommended_prompts:
117
+ return "Error processing image.", "Error processing image."
118
+
119
+ models_text = "Recommended Models:\n" + "\n".join([f"{i+1}. {model}" for i, model in enumerate(recommended_models)])
120
+ prompts_text = "Recommended Prompts:\n" + "\n".join([f"{i+1}. {prompt}" for i, prompt in enumerate(recommended_prompts)])
121
+
122
+ return models_text, prompts_text
123
+ except Exception as e:
124
+ print(f"Error in process_image: {str(e)}")
125
+ return "Error processing image.", "Error processing image."
126
 
127
  # Pre-compute embeddings when starting the application
128
+ try:
129
+ precompute_embeddings()
130
+ except Exception as e:
131
+ print(f"Error in precompute_embeddings: {str(e)}")
132
 
133
  # Create Gradio interface
134
  iface = gr.Interface(