fredaddy committed · verified
Commit 193dc22 · 1 Parent(s): 831b5a4

Update handler.py

Files changed (1):
    handler.py  +14 -6
handler.py CHANGED
@@ -7,21 +7,29 @@ from qwen_vl_utils import process_vision_info
 
 class EndpointHandler:
     def __init__(self, path=""):
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         # Load model and processor
         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-            path, torch_dtype="auto", device_map="auto"
-        )
+            path, torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32, device_map="auto"
+        ).to(self.device)
         self.processor = AutoProcessor.from_pretrained(path)
 
     def __call__(self, data):
         # Extract image and text from the input data
-        image_url = data.get("image", "")
-        text_prompt = data.get("text", "")
+        image_url = data.get("inputs", {}).get("image", "")
+        text_prompt = data.get("inputs", {}).get("text", "")
+
+        if not image_url or not text_prompt:
+            return {"error": "Both 'image' and 'text' must be provided in the input data."}
 
         # Download and process the image
-        image = Image.open(BytesIO(requests.get(image_url).content))
+        try:
+            response = requests.get(image_url)
+            response.raise_for_status()
+            image = Image.open(BytesIO(response.content)).convert("RGB")
+        except Exception as e:
+            return {"error": f"Failed to load image from URL: {e}"}
 
         # Prepare the input in the format expected by the model
         messages = [
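
With this change, __call__ reads its fields from a nested "inputs" object instead of the top level, so callers must wrap the image URL and prompt accordingly. A minimal client sketch follows; the endpoint URL, token, image URL, and prompt are placeholders and not part of this commit.

import requests

# Placeholder deployment values -- substitute your own endpoint URL and token.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

# Payload shape expected by the updated handler: image URL and text prompt under "inputs".
payload = {
    "inputs": {
        "image": "https://example.com/sample.jpg",
        "text": "Describe this image.",
    }
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}", "Content-Type": "application/json"},
    json=payload,
)
print(response.json())

If either field is missing, or the image URL cannot be fetched, the handler now returns a JSON object with an "error" key rather than raising an unhandled exception.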