import torch
from PIL import Image
import base64
from io import BytesIO
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
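# Custom handler for Hugging Face Inference Endpoints: wraps a Qwen2-VL
# checkpoint so the endpoint accepts a base64-encoded image plus a text
# prompt and returns the generated text.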
class EndpointHandler:
    def __init__(self, path=""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load model and processor. device_map="auto" already places the
        # weights on the available device(s); chaining an extra .to() onto a
        # device-mapped model is redundant and can fail for sharded models.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(path)
    def __call__(self, data):
        # Extract the base64-encoded image and text prompt from the payload
        image_data = data.get("inputs", {}).get("image", "")
        text_prompt = data.get("inputs", {}).get("text", "")
        if not image_data or not text_prompt:
            return {"error": "Both 'image' and 'text' must be provided in the input data."}
        # Decode the base64 image into a PIL image
        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process image data: {e}"}
        # Prepare the input in the chat format expected by the model
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_prompt},
                ],
            }
        ]
        # Render the chat template and extract the vision inputs
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Move input tensors to the model's device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Generate output (sampled decoding; generous token budget for long answers)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=2000,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
            )
        # Decode only the newly generated tokens; without trimming, the
        # decoded text would echo the chat template and prompt.
        generated_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs["input_ids"], output_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return {"generated_text": output_text}