FreeAI / app.py
Alibrown's picture
Update app.py
2e485c9 verified
raw
history blame
3.89 kB
import gradio as gr
from transformers import pipeline
# Define all pipelines
def load_pipelines():
    """Build every Hugging Face pipeline offered by the app, once per process.

    Constructing a pipeline loads the model behind it, so the finished
    mapping is memoized on the function object: the first call pays the
    loading cost and every later call returns the very same dict.

    Returns:
        dict: display name (as shown in the model dropdown) -> pipeline.
    """
    cached = getattr(load_pipelines, "_cache", None)
    if cached is not None:
        return cached

    # Display name -> (pipeline task, model id on the Hub).
    model_specs = {
        "GPT-2 Original": ("text-generation", "gpt2"),
        "GPT-2 Medium": ("text-generation", "gpt2-medium"),
        "DistilGPT-2": ("text-generation", "distilgpt2"),
        "German GPT-2": ("text-generation", "german-nlp-group/german-gpt2"),
        "German Wechsel GPT-2": ("text-generation", "benjamin/gpt2-wechsel-german"),
        # T5 is an encoder-decoder model; "text-generation" is the task for
        # causal (decoder-only) LMs, so T5 uses the seq2seq task instead.
        "T5 Base": ("text2text-generation", "t5-base"),
        "T5 Large": ("text2text-generation", "t5-large"),
        "Text Classification": (
            "text-classification",
            "distilbert-base-uncased-finetuned-sst-2-english",
        ),
        "Sentiment Analysis": (
            "sentiment-analysis",
            "nlptown/bert-base-multilingual-uncased-sentiment",
        ),
    }

    loaded = {
        name: pipeline(task, model=model_id)
        for name, (task, model_id) in model_specs.items()
    }
    # Store only after every pipeline was built, so a failed load is retried
    # on the next call instead of caching a partial mapping.
    load_pipelines._cache = loaded
    return loaded
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    model_name,
    max_tokens,
    temperature,
    top_p,
):
    """Answer one chat turn with the model selected in the UI.

    Args:
        message: The user's latest message.
        history: Earlier turns as (user, assistant) pairs; may be empty.
        system_message: Text placed at the top of the generation prompt.
        model_name: Display name of the model to run.
        max_tokens: Number of new tokens a generation model may produce.
        temperature: Sampling temperature for generation models.
        top_p: Nucleus-sampling threshold for generation models.

    Returns:
        str: The generated reply, or a "<kind>: <label> (Confidence: x.xx)"
        line for the classifier models, or an error message. Failures of
        the model call are reported as text rather than raised, so the
        chat UI always receives a string.
    """
    generation_models = (
        "GPT-2 Original", "GPT-2 Medium", "DistilGPT-2",
        "German GPT-2", "German Wechsel GPT-2",
        "T5 Base", "T5 Large",
    )
    # Classifier display name -> (result prefix, error prefix).
    classifier_labels = {
        "Text Classification": ("Classification", "Classification error"),
        "Sentiment Analysis": ("Sentiment", "Sentiment analysis error"),
    }

    # Reject unknown names before touching the (expensive) model loader.
    is_generation = model_name in generation_models
    if not is_generation and model_name not in classifier_labels:
        return "Error: Model not found."

    pipe = load_pipelines().get(model_name)
    if pipe is None:
        return "Error: Model not found."

    if is_generation:
        full_history = (
            ' '.join(f"{turn[0]} {turn[1] or ''}" for turn in history)
            if history else ''
        )
        full_prompt = f"{system_message}\n{full_history}\nUser: {message}\nAssistant:"
        try:
            outputs = pipe(
                full_prompt,
                # Budget the *new* tokens directly; max_length counts tokens
                # including the prompt, so deriving it from the prompt's
                # character count gave a meaningless limit.
                max_new_tokens=int(max_tokens),
                # temperature / top_p are sampling parameters; request
                # sampling explicitly so the sliders take effect.
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=1,
            )
            generated = outputs[0]['generated_text']
            # Drop the echoed prompt only when it is actually there; a
            # pipeline that returns just the continuation must not have
            # its real output sliced away.
            if generated.startswith(full_prompt):
                generated = generated[len(full_prompt):]
            return generated.strip()
        except Exception as e:
            return f"Generation error: {e}"

    result_label, error_label = classifier_labels[model_name]
    try:
        result = pipe(message)[0]
        return f"{result_label}: {result['label']} (Confidence: {result['score']:.2f})"
    except Exception as e:
        return f"{error_label}: {e}"
def create_chat_interface():
    """Assemble the Gradio chat UI, including the model picker and sampling controls."""
    model_choices = [
        "GPT-2 Original", "GPT-2 Medium", "DistilGPT-2",
        "German GPT-2", "German Wechsel GPT-2",
        "T5 Base", "T5 Large",
        "Text Classification", "Sentiment Analysis",
    ]

    # The widgets are listed in the order of respond()'s parameters that
    # follow `message` and `history`.
    system_box = gr.Textbox(value="You are a helpful assistant.", label="System message")
    model_picker = gr.Dropdown(
        choices=model_choices,
        value="GPT-2 Original",
        label="Select Model",
    )
    token_slider = gr.Slider(
        minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
    )
    temperature_slider = gr.Slider(
        minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
    )
    top_p_slider = gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    )

    return gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            system_box,
            model_picker,
            token_slider,
            temperature_slider,
            top_p_slider,
        ],
    )
if __name__ == "__main__":
    # Script entry point: build the chat UI first, then start serving it.
    chat_interface = create_chat_interface()
    # share=True asks Gradio for a public share link in addition to the local URL.
    chat_interface.launch(share=True)