fredaddy committed on
Commit
537f66a
1 Parent(s): a4ca6f4

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -9
handler.py CHANGED
@@ -19,17 +19,17 @@ class EndpointHandler:
19
  # Extract image and text from the input data
20
  image_data = data.get("inputs", {}).get("image", "")
21
  text_prompt = data.get("inputs", {}).get("text", "")
22
-
23
  if not image_data or not text_prompt:
24
  return {"error": "Both 'image' and 'text' must be provided in the input data."}
25
-
26
  # Process the image data
27
  try:
28
  image_bytes = base64.b64decode(image_data)
29
  image = Image.open(BytesIO(image_bytes)).convert("RGB")
30
  except Exception as e:
31
  return {"error": f"Failed to process image data: {e}"}
32
-
33
  # Prepare the input in the format expected by the model
34
  messages = [
35
  {
@@ -40,7 +40,7 @@ class EndpointHandler:
40
  ],
41
  }
42
  ]
43
-
44
  # Process the input
45
  text = self.processor.apply_chat_template(
46
  messages, tokenize=False, add_generation_prompt=True
@@ -53,17 +53,24 @@ class EndpointHandler:
53
  padding=True,
54
  return_tensors="pt",
55
  )
56
-
57
  # Move inputs to the appropriate device
58
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
59
-
60
  # Generate output
61
  with torch.no_grad():
62
- output_ids = self.model.generate(**inputs, max_new_tokens=128)
63
-
 
 
 
 
 
 
 
64
  # Decode the output
65
  output_text = self.processor.batch_decode(
66
  output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
67
  )[0]
68
-
69
  return {"generated_text": output_text}
 
19
  # Extract image and text from the input data
20
  image_data = data.get("inputs", {}).get("image", "")
21
  text_prompt = data.get("inputs", {}).get("text", "")
22
+
23
  if not image_data or not text_prompt:
24
  return {"error": "Both 'image' and 'text' must be provided in the input data."}
25
+
26
  # Process the image data
27
  try:
28
  image_bytes = base64.b64decode(image_data)
29
  image = Image.open(BytesIO(image_bytes)).convert("RGB")
30
  except Exception as e:
31
  return {"error": f"Failed to process image data: {e}"}
32
+
33
  # Prepare the input in the format expected by the model
34
  messages = [
35
  {
 
40
  ],
41
  }
42
  ]
43
+
44
  # Process the input
45
  text = self.processor.apply_chat_template(
46
  messages, tokenize=False, add_generation_prompt=True
 
53
  padding=True,
54
  return_tensors="pt",
55
  )
56
+
57
  # Move inputs to the appropriate device
58
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
59
+
60
  # Generate output
61
  with torch.no_grad():
62
+ output_ids = self.model.generate(
63
+ **inputs,
64
+ max_new_tokens=2000, # Increased from 128 to 2000
65
+ num_return_sequences=1,
66
+ do_sample=True,
67
+ temperature=0.7,
68
+ top_p=0.95
69
+ )
70
+
71
  # Decode the output
72
  output_text = self.processor.batch_decode(
73
  output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
74
  )[0]
75
+
76
  return {"generated_text": output_text}