Reality123b committed (verified)
Commit ac4d454 · 1 Parent(s): 55aa708

Update app.py

Files changed (1): app.py (+28 -48)
app.py CHANGED
@@ -13,8 +13,6 @@ import networkx as nx
 from collections import Counter
 import json
 from datetime import datetime
-from transformers import AutoProcessor, AutoModelForVision2Seq
-from transformers.image_utils import load_image
 
 @dataclass
 class ChatMessage:
@@ -34,9 +32,11 @@ class XylariaChat:
             model="mistralai/Mistral-Nemo-Instruct-2407",
             token=self.hf_token
         )
+
+        self.image_api_url = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
+        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}
 
         self.image_gen_api_url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
-        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}
 
         self.conversation_history = []
         self.persistent_memory = []
@@ -97,13 +97,6 @@ class XylariaChat:
 
         self.chat_history_file = "chat_history.json"
 
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.vlm_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
-        self.vlm_model = AutoModelForVision2Seq.from_pretrained(
-            "HuggingFaceTB/SmolVLM-Instruct",
-            torch_dtype=torch.bfloat16,
-            _attn_implementation="flash_attention_2" if self.device == "cuda" else "eager",
-        ).to(self.device)
 
     def update_internal_state(self, emotion_deltas, cognitive_load_deltas, introspection_delta, engagement_delta):
         for emotion, delta in emotion_deltas.items():
@@ -408,44 +401,34 @@ class XylariaChat:
             print(f"Error resetting API client: {e}")
 
         return None
-
-    def caption_image_vlm(self, image, user_input):
+
+    def caption_image(self, image):
         try:
-
-            if isinstance(image, str) and image.startswith('http'):
-                image = load_image(image)
-            elif isinstance(image, str) and os.path.isfile(image):
-                image = Image.open(image)
-            elif isinstance(image, str) and image.startswith('data:image'):
-                image = Image.open(base64.b64decode(image.split(',')[1]))
+            if isinstance(image, str) and os.path.isfile(image):
+                with open(image, "rb") as f:
+                    data = f.read()
+            elif isinstance(image, str):
+                if image.startswith('data:image'):
+                    image = image.split(',')[1]
+                data = base64.b64decode(image)
             else:
-                image = Image.fromarray(image)
-
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "image"},
-                        {"type": "text", "text": user_input}
-                    ]
-                },
-            ]
-
-            prompt = self.vlm_processor.apply_chat_template(messages, add_generation_prompt=True)
-            inputs = self.vlm_processor(text=prompt, images=[image], return_tensors="pt")
-            inputs = inputs.to(self.device)
-
-            generated_ids = self.vlm_model.generate(**inputs, max_new_tokens=500)
-            generated_texts = self.vlm_processor.batch_decode(
-                generated_ids,
-                skip_special_tokens=True,
+                data = image.read()
+
+            response = requests.post(
+                self.image_api_url,
+                headers=self.image_api_headers,
+                data=data
             )
-
-            return generated_texts[0].split("Assistant: ")[-1]
 
-        except Exception as e:
-            return f"Error captioning image with VLM: {str(e)}"
+            if response.status_code == 200:
+                caption = response.json()[0].get('generated_text', 'No caption generated')
+                return caption
+            else:
+                return f"Error captioning image: {response.status_code} - {response.text}"
 
+        except Exception as e:
+            return f"Error processing image: {str(e)}"
+
     def generate_image(self, prompt):
         try:
             payload = {"inputs": prompt}
@@ -501,11 +484,8 @@ class XylariaChat:
             messages.append(msg)
 
         if image:
-            image_caption = self.caption_image_vlm(image, user_input)
-            messages.append(ChatMessage(
-                role="user",
-                content=image_caption
-            ).to_dict())
+            image_caption = self.caption_image(image)
+            user_input = f"description of an image: {image_caption}\n\nUser's message about it: {user_input}"
 
         messages.append(ChatMessage(
             role="user",
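
The net effect of the commit: instead of loading SmolVLM locally, caption_image() now posts raw image bytes to the hosted BLIP endpoint, and the chat path folds the returned caption into user_input. A minimal standalone sketch of the same request pattern, assuming a valid token in an HF_TOKEN environment variable and a sample file name (both illustrative, not taken from the diff):

import os
import requests

API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
HEADERS = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

def caption_bytes(data: bytes) -> str:
    # The serverless Inference API takes the raw image bytes as the POST body;
    # for BLIP it answers with a JSON list like [{"generated_text": "..."}].
    response = requests.post(API_URL, headers=HEADERS, data=data)
    if response.status_code != 200:
        raise RuntimeError(f"Captioning failed: {response.status_code} - {response.text}")
    return response.json()[0].get("generated_text", "No caption generated")

# Hypothetical usage, mirroring the os.path.isfile branch of caption_image():
with open("photo.jpg", "rb") as f:
    print(caption_bytes(f.read()))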
 
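One subtle branch in the new caption_image() is the data-URI case: an image can arrive as a string like "data:image/png;base64,<payload>", where only the text after the first comma is base64 data. An isolated sketch of that decoding step (the function name is illustrative):

import base64

def data_uri_to_bytes(uri: str) -> bytes:
    # Strip the "data:image/...;base64," metadata prefix, then decode.
    if uri.startswith("data:image"):
        uri = uri.split(",", 1)[1]
    return base64.b64decode(uri)

# Example: data_uri_to_bytes("data:image/png;base64,aGVsbG8=") == b"hello"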
 
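For comparison, the generate_image() path visible in the context lines sends a JSON payload rather than raw bytes. A hedged sketch of that call, assuming the serverless Inference API's usual behavior of returning the generated image as raw bytes for text-to-image models (names are illustrative; only the endpoint URL and the {"inputs": prompt} payload appear in the diff):

import os
import requests

GEN_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
HEADERS = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

def generate_image_bytes(prompt: str) -> bytes:
    # Text-to-image models take {"inputs": prompt} and, on success,
    # respond with the encoded image file (e.g. PNG/JPEG bytes).
    response = requests.post(GEN_URL, headers=HEADERS, json={"inputs": prompt})
    response.raise_for_status()
    return response.content

with open("generated.png", "wb") as f:
    f.write(generate_image_bytes("a watercolor fox"))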