Qwen-VL-7B-2 / handler.py
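"""
Custom handler for serving Qwen2-VL as an image + text generation endpoint,
following the EndpointHandler / __call__ contract used by Hugging Face
Inference Endpoints. The handler decodes a base64-encoded image, builds a
chat-style prompt, and returns the model's generated text.
"""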
import torch
from PIL import Image
import base64
from io import BytesIO
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
class EndpointHandler:
def __init__(self, path=""):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load model and processor. device_map="auto" lets accelerate place the
        # weights on the available device(s), so the model is not moved again with .to().
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            device_map="auto",
        )
self.processor = AutoProcessor.from_pretrained(path)
def __call__(self, data):
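        # Expected request payload (parsed just below):
        #   {"inputs": {"image": "<base64-encoded image>", "text": "<prompt>"}}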
# Extract image and text from the input data
image_data = data.get("inputs", {}).get("image", "")
text_prompt = data.get("inputs", {}).get("text", "")
if not image_data or not text_prompt:
return {"error": "Both 'image' and 'text' must be provided in the input data."}
# Process the image data
try:
image_bytes = base64.b64decode(image_data)
image = Image.open(BytesIO(image_bytes)).convert("RGB")
except Exception as e:
return {"error": f"Failed to process image data: {e}"}
# Prepare the input in the format expected by the model
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": text_prompt},
],
}
]
# Process the input
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
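        # process_vision_info extracts the image/video entries from the chat
        # messages in the format the processor expects.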
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
# Move inputs to the appropriate device
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Generate output
with torch.no_grad():
output_ids = self.model.generate(
**inputs,
max_new_tokens=2000, # Increased from 128 to 2000
num_return_sequences=1,
do_sample=True,
temperature=0.7,
top_p=0.95
)
        # Decode only the newly generated tokens; generate() echoes the prompt ids,
        # so the prompt portion is sliced off before decoding.
        trimmed_ids = [out[len(inp):] for inp, out in zip(inputs["input_ids"], output_ids)]
        output_text = self.processor.batch_decode(
            trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
return {"generated_text": output_text}