fredaddy committed
Commit 8a6ef8a · verified · 1 Parent(s): 250141d

Create handler.py

Files changed (1)
  1. handler.py +62 -0
handler.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ from PIL import Image
+ import requests
+ from io import BytesIO
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Load model and processor
+         self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+             path, torch_dtype="auto", device_map="auto"
+         )
+         self.processor = AutoProcessor.from_pretrained(path)
+
+     def __call__(self, data):
+         # Extract the image URL and text prompt from the request payload
+         image_url = data.get("image", "")
+         text_prompt = data.get("text", "")
+
+         # Download and decode the image; fail fast on HTTP errors
+         response = requests.get(image_url, timeout=30)
+         response.raise_for_status()
+         image = Image.open(BytesIO(response.content))
+
+         # Prepare the input in the chat format expected by the model
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": image},
+                     {"type": "text", "text": text_prompt},
+                 ],
+             }
+         ]
+
+         # Render the chat template and collect the vision inputs
+         text = self.processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         image_inputs, video_inputs = process_vision_info(messages)
+         inputs = self.processor(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+
+         # Move inputs to the appropriate device
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         # Generate output
+         with torch.no_grad():
+             output_ids = self.model.generate(**inputs, max_new_tokens=128)
+
+         # Drop the echoed prompt tokens so only newly generated text is decoded
+         trimmed_ids = [
+             out[len(inp):] for inp, out in zip(inputs["input_ids"], output_ids)
+         ]
+         output_text = self.processor.batch_decode(
+             trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+         )[0]
+
+         return {"generated_text": output_text}
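
For reference, a minimal local smoke test of this handler might look like the sketch below. The model path and image URL are placeholders, not part of this commit; on a Hugging Face Inference Endpoint, EndpointHandler is instantiated with the local repository path and called with the parsed JSON request body, so the same {"image": ..., "text": ...} payload shape applies.

# Hypothetical smoke test; MODEL_PATH and the image URL are placeholders.
from handler import EndpointHandler

MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"  # assumption: any Qwen2-VL checkpoint

handler = EndpointHandler(path=MODEL_PATH)
result = handler(
    {
        "image": "https://example.com/sample.jpg",
        "text": "Describe this image.",
    }
)
print(result["generated_text"])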