Spaces:
No application file
No application file
from collections.abc import Mapping

import gradio as gr
from transformers import pipeline
# Registry of selectable models: dropdown label -> (pipeline task, Hub model id).
# T5 checkpoints are encoder-decoder models, so they must use the
# "text2text-generation" task; the "text-generation" pipeline only accepts
# decoder-only (causal) language models and cannot load T5.
MODEL_SPECS = {
    "GPT-2 Original": ("text-generation", "gpt2"),
    "GPT-2 Medium": ("text-generation", "gpt2-medium"),
    "DistilGPT-2": ("text-generation", "distilgpt2"),
    "German GPT-2": ("text-generation", "german-nlp-group/german-gpt2"),
    "German Wechsel GPT-2": ("text-generation", "benjamin/gpt2-wechsel-german"),
    "T5 Base": ("text2text-generation", "t5-base"),
    "T5 Large": ("text2text-generation", "t5-large"),
    "Text Classification": (
        "text-classification",
        "distilbert-base-uncased-finetuned-sst-2-english",
    ),
    "Sentiment Analysis": (
        "sentiment-analysis",
        "nlptown/bert-base-multilingual-uncased-sentiment",
    ),
}


class _LazyPipelines(Mapping):
    """Read-only mapping of model label -> pipeline that loads on first access.

    Each model is built at most once per process and only when it is actually
    requested. This avoids re-loading every model on every chat message and
    holding all of them in memory at once.
    """

    def __init__(self, specs):
        self._specs = specs
        self._loaded = {}

    def __getitem__(self, name):
        if name not in self._loaded:
            # Unknown names raise KeyError here, before anything is loaded;
            # Mapping.get() turns that into its default (None).
            task, model = self._specs[name]
            self._loaded[name] = pipeline(task, model=model)
        return self._loaded[name]

    def __contains__(self, name):
        # Membership tests must not trigger a model download/load.
        return name in self._specs

    def __iter__(self):
        return iter(self._specs)

    def __len__(self):
        return len(self._specs)


# Shared, process-wide registry; constructing it loads no models.
_PIPELINES = _LazyPipelines(MODEL_SPECS)


def load_pipelines():
    """Return the shared mapping of model label -> transformers pipeline.

    Pipelines are created lazily on first lookup and cached, so calling this
    repeatedly is cheap. As with a plain dict, ``.get(name)`` returns ``None``
    for an unknown label and ``[name]`` raises ``KeyError``.
    """
    return _PIPELINES
# Dropdown labels served by a generative (text-generation / text2text) pipeline.
_GENERATION_MODELS = (
    "GPT-2 Original", "GPT-2 Medium", "DistilGPT-2",
    "German GPT-2", "German Wechsel GPT-2",
    "T5 Base", "T5 Large",
)
# Dropdown labels served by a classifier -> (result prefix, error prefix).
_CLASSIFIER_MODELS = {
    "Text Classification": ("Classification", "Classification error"),
    "Sentiment Analysis": ("Sentiment", "Sentiment analysis error"),
}
# Speaker labels used for messages-format history entries.
_ROLE_LABELS = {"user": "User", "assistant": "Assistant"}


def _format_history(history):
    """Render past turns as ``User: ...`` / ``Assistant: ...`` lines.

    Accepts Gradio's tuple format (``[user, assistant]`` pairs, where the
    assistant part may be ``None``) as well as the messages format (dicts
    with ``role``/``content``). Messages with other roles are skipped.
    """
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            label = _ROLE_LABELS.get(turn.get("role"))
            if label:
                lines.append(f"{label}: {turn.get('content') or ''}")
        else:
            user_msg, bot_msg = turn[0], turn[1]
            lines.append(f"User: {user_msg}")
            if bot_msg:
                lines.append(f"Assistant: {bot_msg}")
    return "\n".join(lines)


def _extract_reply(generated, prompt):
    """Strip an echoed prompt and cut the text at the model's next user turn."""
    # text-generation pipelines return prompt + continuation; text2text
    # pipelines return only the new text, so strip only when echoed.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    # Small causal LMs tend to keep writing the dialogue; keep only this turn.
    return generated.split("\nUser:", 1)[0].strip()


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    model_name,
    max_tokens,
    temperature,
    top_p,
):
    """Answer one chat turn with the model selected in the UI.

    Args:
        message: The user's new message.
        history: Previous turns, as ``[user, assistant]`` pairs or
            role/content dicts.
        system_message: Instruction text placed at the top of the prompt
            (generative models only).
        model_name: One of the dropdown labels.
        max_tokens: Maximum number of *new* tokens to generate.
        temperature: Sampling temperature (generative models only).
        top_p: Nucleus-sampling probability mass (generative models only).

    Returns:
        The reply text. Failures are reported as an error string shown in the
        chat rather than raised.
    """
    # Reject unknown labels before touching the (expensive) model registry.
    if model_name not in _GENERATION_MODELS and model_name not in _CLASSIFIER_MODELS:
        return "Error: Model not found."

    try:
        pipe = load_pipelines().get(model_name)
    except Exception as e:  # UI boundary: surface load failures in the chat
        return f"Error: could not load model '{model_name}': {e}"
    if pipe is None:
        return "Error: Model not found."

    if model_name in _GENERATION_MODELS:
        parts = [
            system_message,
            _format_history(history),
            f"User: {message}",
            "Assistant:",
        ]
        prompt = "\n".join(part for part in parts if part)
        try:
            outputs = pipe(
                prompt,
                # Token budget for the reply itself; the previous
                # max_length=len(prompt)+N mixed characters with tokens.
                max_new_tokens=int(max_tokens),
                # Without sampling, generation is greedy and temperature /
                # top_p are ignored.
                do_sample=True,
                temperature=float(temperature),
                top_p=float(top_p),
                num_return_sequences=1,
            )
            return _extract_reply(outputs[0]["generated_text"], prompt)
        except Exception as e:
            return f"Generation error: {e}"

    # Classification and sentiment models.
    label, error_label = _CLASSIFIER_MODELS[model_name]
    try:
        result = pipe(message)[0]
        return f"{label}: {result['label']} (Confidence: {result['score']:.2f})"
    except Exception as e:
        return f"{error_label}: {e}"
def create_chat_interface():
    """Build the chat UI: a Gradio ChatInterface wired to ``respond``.

    The extra controls are passed to ``respond`` after the message and
    history, so their order below must match its parameters: system message,
    model name, max tokens, temperature, top-p.
    """
    model_choices = [
        "GPT-2 Original", "GPT-2 Medium", "DistilGPT-2",
        "German GPT-2", "German Wechsel GPT-2",
        "T5 Base", "T5 Large",
        "Text Classification", "Sentiment Analysis",
    ]

    system_box = gr.Textbox(value="You are a helpful assistant.", label="System message")
    model_picker = gr.Dropdown(model_choices, value="GPT-2 Original", label="Select Model")
    max_tokens_slider = gr.Slider(
        minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
    )
    temperature_slider = gr.Slider(
        minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
    )
    top_p_slider = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
    )

    extra_inputs = [
        system_box,
        model_picker,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
    ]
    return gr.ChatInterface(respond, additional_inputs=extra_inputs)
if __name__ == "__main__":
    # Build the UI and start the Gradio server; share=True also asks Gradio
    # for a public share link.
    create_chat_interface().launch(share=True)