File size: 2,704 Bytes
8a6ef8a
 
a4ca6f4
8a6ef8a
 
 
 
 
 
193dc22
8a6ef8a
 
 
193dc22
 
8a6ef8a
 
 
 
a4ca6f4
193dc22
537f66a
a4ca6f4
193dc22
537f66a
a4ca6f4
193dc22
a4ca6f4
 
193dc22
a4ca6f4
537f66a
8a6ef8a
 
 
 
 
 
 
 
 
 
537f66a
8a6ef8a
 
 
 
 
 
 
 
 
 
 
 
537f66a
8a6ef8a
 
537f66a
8a6ef8a
 
537f66a
 
 
 
 
 
 
 
 
8a6ef8a
 
 
 
537f66a
8a6ef8a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from PIL import Image
import base64
from io import BytesIO
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

class EndpointHandler:
    """Inference-endpoint handler for a Qwen2-VL image + text chat model.

    The handler is constructed once with the model directory and then called
    with a request payload shaped like::

        {"inputs": {"image": "<base64 image or data: URI>", "text": "<prompt>"}}

    It returns ``{"generated_text": ...}`` on success, or ``{"error": ...}``
    when the payload is malformed or the image cannot be decoded.
    """

    def __init__(self, path=""):
        """Load the Qwen2-VL model and its processor from *path*.

        Args:
            path: Model directory / identifier passed to ``from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # device_map="auto" already lets accelerate place the weights. Calling
        # .to() on a dispatched model is unsupported when its weights are split
        # across devices or offloaded, so the model is not moved again here;
        # inputs are sent to self.model.device in __call__ instead.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(path)

    @staticmethod
    def _b64_payload(image_data):
        """Return the bare base64 text of *image_data*.

        A ``data:<mime>;base64,`` URI prefix is stripped so clients may send
        either a plain base64 string or a data URI. Anything else is returned
        unchanged. ':' is not in the base64 alphabet, so a plain payload can
        never start with ``data:``.
        """
        if isinstance(image_data, str) and image_data.startswith("data:"):
            return image_data.partition(",")[2]
        return image_data

    @staticmethod
    def _strip_prompt(input_ids, output_ids):
        """Drop the echoed prompt tokens from each generated sequence.

        ``generate`` returns prompt + continuation for each row, so the first
        ``len(prompt)`` ids of every output row are removed.
        """
        return [out[len(inp):] for inp, out in zip(input_ids, output_ids)]

    def __call__(self, data):
        """Answer a text prompt about a base64-encoded image.

        Args:
            data: Request dict. ``data["inputs"]`` must be a dict holding
                ``"image"`` (base64 string or ``data:`` URI) and ``"text"``.

        Returns:
            ``{"generated_text": str}`` containing only the model's reply, or
            ``{"error": str}`` if fields are missing/invalid or the image
            cannot be decoded.
        """
        payload = data.get("inputs") or {}
        if not isinstance(payload, dict):
            return {"error": "'inputs' must be an object with 'image' and 'text' fields."}

        image_data = payload.get("image", "")
        text_prompt = payload.get("text", "")
        if not image_data or not text_prompt:
            return {"error": "Both 'image' and 'text' must be provided in the input data."}

        # Any decoding failure is a client error: report it in the response
        # rather than letting it crash the endpoint.
        try:
            image_bytes = base64.b64decode(self._b64_payload(image_data))
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process image data: {e}"}

        # Single-turn chat message in the format the Qwen2-VL processor expects.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_prompt},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Send inputs where the (possibly dispatched) model reports it lives;
        # with device_map="auto" this need not equal self.device.
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=2000,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
            )

        # Keep only the newly generated tokens so the reply does not echo the
        # chat-templated prompt.
        reply_ids = self._strip_prompt(inputs["input_ids"], output_ids)
        output_text = self.processor.batch_decode(
            reply_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        return {"generated_text": output_text}