import base64
import io
import json
import subprocess
from copy import deepcopy
from typing import Dict, List

import gradio as gr
import librosa
import requests
from PIL import Image


IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")
VIDEO_EXTENSIONS = (".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v")
AUDIO_EXTENSIONS = (".mp3", ".wav", ".flac", ".m4a")
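
# str.endswith accepts a tuple of suffixes, so these constants can be passed
# to it directly, e.g. "demo.webm".endswith(VIDEO_EXTENSIONS) -> True.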
|
|
|
DEFAULT_SAMPLING_PARAMS = {
    "top_p": 0.8,
    "top_k": 100,
    "temperature": 0.7,
    "do_sample": True,
    "num_beams": 1,
    "repetition_penalty": 1.2,
}
MAX_NEW_TOKENS = 1024
|
|
|
|
|
|
|
def load_image_to_base64(image_path):
    """Load an image and return it as a base64-encoded PNG string."""
    with Image.open(image_path) as img:
        if img.mode != 'RGB':
            img = img.convert('RGB')
        img_byte_arr = io.BytesIO()
        img.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()
    return base64.b64encode(img_byte_arr).decode('utf-8')
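
# Illustrative round-trip ("example.jpg" is a hypothetical file, not part of
# the demo): the base64 string decodes back to the PNG bytes produced above.
#
#   b64 = load_image_to_base64("example.jpg")
#   png_bytes = base64.b64decode(b64)
#   Image.open(io.BytesIO(png_bytes)).format  # -> "PNG"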
|
|
|
def wav_to_bytes_with_ffmpeg(wav_file_path):
    """Re-encode an audio file to WAV with ffmpeg and return it base64-encoded."""
    # ffmpeg writes the converted stream to stdout ('-'); stderr is captured
    # so ffmpeg's log output does not pollute the console.
    process = subprocess.Popen(
        ['ffmpeg', '-i', wav_file_path, '-f', 'wav', '-'],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )
    out, _ = process.communicate()
    return base64.b64encode(out).decode('utf-8')
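
# Illustrative usage (assumes ffmpeg is on PATH; "clip.wav" is a hypothetical
# file, not part of the demo flow):
#
#   b64_audio = wav_to_bytes_with_ffmpeg("clip.wav")
#   wav_bytes = base64.b64decode(b64_audio)  # normalized WAV container bytes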
|
|
|
def parse_sse_response(response):
    """Yield text chunks from a server-sent events (SSE) streaming response."""
    for line in response.iter_lines():
        if line:
            line = line.decode('utf-8')
            if line.startswith('data: '):
                data = line[6:]  # strip the "data: " prefix
                if data == '[DONE]':
                    break
                try:
                    json_data = json.loads(data)
                    yield json_data['text']
                except json.JSONDecodeError:
                    raise gr.Error(f"Failed to parse JSON: {data}")
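
# Expected wire format (a sketch of what the server is assumed to emit; each
# "text" field is yielded as one streamed chunk):
#
#   data: {"text": "Hello"}
#   data: {"text": " world"}
#   data: [DONE]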
|
|
|
def history2messages(history: List[Dict]) -> List[Dict]:
    """Transform Gradio chat history into the chat-message format the API expects."""
    messages = []
    cur_message = dict()
    for item in history:
        if item["role"] == "assistant":
            # Flush any pending user message before appending the assistant turn.
            if len(cur_message) > 0:
                messages.append(deepcopy(cur_message))
                cur_message = dict()
            messages.append(deepcopy(item))
            continue

        if "role" not in cur_message:
            cur_message["role"] = "user"
        if "content" not in cur_message:
            cur_message["content"] = dict()

        # Plain-text entries carry no metadata title; files are tagged by type.
        if "metadata" not in item:
            item["metadata"] = {"title": None}
        if item["metadata"]["title"] is None:
            cur_message["content"]["text"] = item["content"]
        elif item["metadata"]["title"] == "image":
            cur_message["content"]["image"] = load_image_to_base64(item["content"][0])
        elif item["metadata"]["title"] == "audio":
            cur_message["content"]["audio"] = wav_to_bytes_with_ffmpeg(item["content"][0])
    if len(cur_message) > 0:
        messages.append(cur_message)
    return messages
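
# Illustrative transformation ("cat.png" is hypothetical; base64 abbreviated):
#
#   [{"role": "user", "content": ("cat.png",), "metadata": {"title": "image"}},
#    {"role": "user", "content": "What is this?"},
#    {"role": "assistant", "content": "A cat."}]
#
# becomes
#
#   [{"role": "user", "content": {"image": "<base64 PNG>", "text": "What is this?"}},
#    {"role": "assistant", "content": "A cat."}]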
|
|
|
def check_messages(history, message, audio):
    has_text = message["text"] and message["text"].strip()
    has_files = len(message["files"]) > 0
    has_audio = audio is not None

    if not (has_text or has_files or has_audio):
        raise gr.Error("Please enter text or upload an audio/image file before sending.")

    audios = []
    images = []

    for file_msg in message["files"]:
        if file_msg.endswith(AUDIO_EXTENSIONS) or file_msg.endswith(VIDEO_EXTENSIONS):
            # Note: librosa >= 0.10 renames this keyword from `filename` to `path`.
            duration = librosa.get_duration(filename=file_msg)
            if duration > 30:
                raise gr.Error("Audio duration must not exceed 30 seconds.")
            if duration == 0:
                raise gr.Error("Audio duration must not be 0 seconds.")
            audios.append(file_msg)
        elif file_msg.endswith(IMAGE_EXTENSIONS):
            images.append(file_msg)
        else:
            filename = file_msg.split("/")[-1]
            raise gr.Error(f"Unsupported file type: {filename}. It should be an image or audio file.")

    if len(audios) > 1:
        raise gr.Error("Please upload only one audio file.")

    if len(images) > 1:
        raise gr.Error("Please upload only one image file.")

    if audio is not None:
        if len(audios) > 0:
            raise gr.Error("Please upload only one audio file or record audio.")
        audios.append(audio)

    for image in images:
        history.append({"role": "user", "content": (image,), "metadata": {"title": "image"}})

    for audio in audios:
        history.append({"role": "user", "content": (audio,), "metadata": {"title": "audio"}})

    if message["text"]:
        history.append({"role": "user", "content": message["text"]})

    return history, gr.MultimodalTextbox(value=None, interactive=False), None
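
# The returned triple maps onto the .submit() outputs wired below: the updated
# chat history, a cleared (temporarily non-interactive) MultimodalTextbox, and
# None to reset the microphone/audio component.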
|
|
|
def bot(
    history: list,
    top_p: float,
    top_k: int,
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int = MAX_NEW_TOKENS,
    regenerate: bool = False,
):
    # On regenerate, drop the last (assistant) message and answer again.
    if history and regenerate:
        history = history[:-1]

    if not history:
        # Nothing to respond to; hand the history back unchanged.
        yield history
        return

    msgs = history2messages(history)

    API_URL = "http://8.152.0.142:8000/v1/chat"

    payload = {
        "messages": msgs,
        "sampling_params": {
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            # Beam search cannot stream tokens incrementally, so keep the
            # default single beam for this streaming endpoint.
            "num_beams": DEFAULT_SAMPLING_PARAMS["num_beams"],
        }
    }

    # POST with a streaming SSE response; a JSON body on a GET request is
    # non-standard and is dropped by many servers and proxies.
    response = requests.post(
        API_URL,
        json=payload,
        headers={'Accept': 'text/event-stream'},
        stream=True
    )
    response_text = ""

    # Accumulate the streamed chunks and re-render the chat after each one.
    for text in parse_sse_response(response):
        response_text += text
        yield history + [{"role": "assistant", "content": response_text}]
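
# Request sketch for bot() above (single-turn, text-only; values assumed):
#
#   {"messages": [{"role": "user", "content": {"text": "Hi"}}],
#    "sampling_params": {"top_p": 0.8, "top_k": 100, "temperature": 0.7,
#                        "repetition_penalty": 1.2, "max_new_tokens": 1024,
#                        "num_beams": 1}}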
|
|
|
def change_state(state):
    return gr.update(visible=not state), not state
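
# Example: with the panel hidden (state False), change_state returns
# (gr.update(visible=True), True), showing the panel and flipping the flag.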
|
|
|
def reset_user_input():
    return gr.update(value="")
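
# Note: reset_user_input is not wired to any event below; the textbox is
# cleared via the MultimodalTextbox returned from check_messages instead.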
|
|
|
if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 🪐 Chat with <a href="https://github.com/infinigence/Infini-Megrez-Omni">Megrez-3B-Omni</a>
            """
        )
        chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages", height='48vh')
|
|
|
        sampling_params_group_hidden_state = gr.State(False)

        with gr.Row(equal_height=True):
            chat_input = gr.MultimodalTextbox(
                file_count="multiple",
                placeholder="Enter your prompt or upload image/audio here, then press ENTER...",
                show_label=False,
                scale=8,
                file_types=["image", "audio"],
                interactive=True,
            )
        with gr.Row(equal_height=True):
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                scale=1,
                max_length=30
            )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=150):
                with gr.Row(equal_height=True):
                    regenerate_btn = gr.Button("Regenerate", variant="primary")
                    clear_btn = gr.ClearButton(
                        [chat_input, audio_input, chatbot],
                    )
|
|
|
        with gr.Row():
            sampling_params_toggle_btn = gr.Button("Sampling Parameters")

        with gr.Group(visible=False) as sampling_params_group:
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0, maximum=1.2, value=DEFAULT_SAMPLING_PARAMS["temperature"], label="Temperature"
                )
                repetition_penalty = gr.Slider(
                    minimum=0,
                    maximum=2,
                    value=DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
                    label="Repetition Penalty",
                )

            with gr.Row():
                top_p = gr.Slider(minimum=0, maximum=1, value=DEFAULT_SAMPLING_PARAMS["top_p"], label="Top-p")
                top_k = gr.Slider(minimum=0, maximum=1000, value=DEFAULT_SAMPLING_PARAMS["top_k"], label="Top-k")

            with gr.Row():
                max_new_tokens = gr.Slider(
                    minimum=1,
                    maximum=MAX_NEW_TOKENS,
                    value=MAX_NEW_TOKENS,
                    label="Max New Tokens",
                    interactive=True,
                )

        sampling_params_toggle_btn.click(
            change_state,
            sampling_params_group_hidden_state,
            [sampling_params_group, sampling_params_group_hidden_state],
        )
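
        # Event chain below: check_messages validates and stages the user turn,
        # bot streams the assistant reply, and the textbox is then re-enabled.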
|
|
|
        chat_msg = chat_input.submit(
            check_messages,
            [chatbot, chat_input, audio_input],
            [chatbot, chat_input, audio_input],
        )

        bot_msg = chat_msg.then(
            bot,
            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens],
            outputs=chatbot,
            api_name="bot_response",
        )

        bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
|
|
|
        regenerate_btn.click(
            bot,
            inputs=[chatbot, top_p, top_k, temperature, repetition_penalty, max_new_tokens, gr.State(True)],
            outputs=chatbot,
        )

    demo.launch()