"""Custom EndpointHandler that serves Qwen2-VL image + text generation requests."""

import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

from qwen_vl_utils import process_vision_info


class EndpointHandler:
    def __init__(self, path=""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # device_map="auto" lets accelerate place the weights, so the model is not
        # moved again with .to(); fp16 is only used when a GPU is available.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(path)
    def __call__(self, data):
        # Expected payload: {"inputs": {"image": "<base64-encoded image>", "text": "<prompt>"}}
        image_data = data.get("inputs", {}).get("image", "")
        text_prompt = data.get("inputs", {}).get("text", "")

        if not image_data or not text_prompt:
            return {"error": "Both 'image' and 'text' must be provided in the input data."}

        # Decode the base64 payload into an RGB PIL image.
        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process image data: {e}"}

        # Build a single-turn chat message that pairs the image with the text prompt.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_prompt},
                ],
            }
        ]

        # Render the chat template and prepare the multimodal inputs for the model.
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)

        # Generate a response without tracking gradients.
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=2000,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
            )

        # Trim the prompt tokens so only the newly generated text is decoded.
        generated_ids = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], output_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        return {"generated_text": output_text}