"""Gradio chat demo that routes messages to a selectable Hugging Face pipeline.

Generative models (GPT-2 variants, T5) continue a plain-text chat transcript;
the classification models label the latest user message instead.
"""

import functools

import gradio as gr
from transformers import pipeline

# Display name -> (pipeline task, model checkpoint). Insertion order is the
# order shown in the model dropdown.
# T5 is an encoder-decoder model: the "text-generation" task only accepts
# causal language models and cannot load T5, so T5 uses "text2text-generation".
MODEL_SPECS = {
    "GPT-2 Original": ("text-generation", "gpt2"),
    "GPT-2 Medium": ("text-generation", "gpt2-medium"),
    "DistilGPT-2": ("text-generation", "distilgpt2"),
    "German GPT-2": ("text-generation", "german-nlp-group/german-gpt2"),
    "German Wechsel GPT-2": ("text-generation", "benjamin/gpt2-wechsel-german"),
    "T5 Base": ("text2text-generation", "t5-base"),
    "T5 Large": ("text2text-generation", "t5-large"),
    "Text Classification": (
        "text-classification",
        "distilbert-base-uncased-finetuned-sst-2-english",
    ),
    "Sentiment Analysis": (
        "sentiment-analysis",
        "nlptown/bert-base-multilingual-uncased-sentiment",
    ),
}

# Tasks whose output is free text returned as the assistant reply.
GENERATIVE_TASKS = frozenset({"text-generation", "text2text-generation"})

# Classifier display name -> (result prefix, error prefix).
CLASSIFIER_LABELS = {
    "Text Classification": ("Classification", "Classification error"),
    "Sentiment Analysis": ("Sentiment", "Sentiment analysis error"),
}


@functools.lru_cache(maxsize=None)
def get_pipeline(model_name):
    """Return the pipeline for *model_name*, loading it on first use.

    Cached so each model is loaded at most once per process rather than on
    every chat message. The key space is bounded by MODEL_SPECS, so an
    unbounded cache is safe.

    Raises:
        KeyError: if *model_name* is not a key of MODEL_SPECS.
    """
    task, checkpoint = MODEL_SPECS[model_name]
    return pipeline(task, model=checkpoint)


def load_pipelines():
    """Return a dict mapping every display name to its (cached) pipeline.

    This loads every model in MODEL_SPECS. respond() uses get_pipeline()
    instead so only the selected model is loaded.
    """
    return {name: get_pipeline(name) for name in MODEL_SPECS}


def _format_history(history):
    """Render Gradio chat history as "User: ..." / "Assistant: ..." lines.

    Accepts both Gradio history formats: [user, assistant] pairs ("tuples")
    and {"role": ..., "content": ...} dicts ("messages"). Empty entries and
    roles other than user/assistant are skipped.
    """
    speakers = {"user": "User", "assistant": "Assistant"}
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            speaker = speakers.get(turn.get("role"))
            content = turn.get("content")
            if speaker and content:
                lines.append(f"{speaker}: {content}")
        else:
            user_msg, bot_msg = turn[0], turn[1]
            if user_msg:
                lines.append(f"User: {user_msg}")
            if bot_msg:
                lines.append(f"Assistant: {bot_msg}")
    return lines


def _build_prompt(system_message, history, message):
    """Build a transcript prompt that ends with an open "Assistant:" turn."""
    parts = [
        system_message,
        *_format_history(history),
        f"User: {message}",
        "Assistant:",
    ]
    return "\n".join(part for part in parts if part)


def _generate(pipe, task, prompt, max_tokens, temperature, top_p):
    """Run a generative pipeline on *prompt* and return only the new reply."""
    generate_kwargs = {
        # Counts generated tokens only; max_length would also count prompt
        # tokens (and must be in tokens, not characters).
        "max_new_tokens": int(max_tokens),
        # Without sampling, generation is greedy and temperature/top_p are
        # ignored.
        "do_sample": True,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "num_return_sequences": 1,
    }
    if task == "text-generation":
        # Causal-LM pipelines echo the prompt by default; ask for the
        # continuation only instead of slicing the prompt off by length.
        generate_kwargs["return_full_text"] = False
    try:
        text = pipe(prompt, **generate_kwargs)[0]["generated_text"]
    except Exception as e:
        return f"Generation error: {e}"
    # Models often go on to invent the next user turn; keep only the reply.
    return text.split("\nUser:", 1)[0].strip()


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    model_name,
    max_tokens,
    temperature,
    top_p,
):
    """Produce the bot reply for one chat turn using the selected model.

    Args:
        message: The latest user message.
        history: Previous turns as supplied by gr.ChatInterface.
        system_message: Instruction text placed at the top of the prompt.
        model_name: A key of MODEL_SPECS (the dropdown value).
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (generative models only).
        top_p: Nucleus-sampling probability mass (generative models only).

    Returns:
        The reply string. Errors are returned as text rather than raised so
        they are shown in the chat window.
    """
    if model_name not in MODEL_SPECS:
        return "Error: Model not found."
    task, _ = MODEL_SPECS[model_name]
    try:
        pipe = get_pipeline(model_name)
    except Exception as e:
        return f"Error loading model '{model_name}': {e}"

    if task in GENERATIVE_TASKS:
        prompt = _build_prompt(system_message, history, message)
        return _generate(pipe, task, prompt, max_tokens, temperature, top_p)

    # Classification / sentiment: label the latest message only.
    prefix, error_prefix = CLASSIFIER_LABELS[model_name]
    try:
        result = pipe(message)[0]
    except Exception as e:
        return f"{error_prefix}: {e}"
    return f"{prefix}: {result['label']} (Confidence: {result['score']:.2f})"


def create_chat_interface():
    """Create Gradio ChatInterface with model selection."""
    demo = gr.ChatInterface(
        respond,
        additional_inputs=[
            gr.Textbox(value="You are a helpful assistant.", label="System message"),
            gr.Dropdown(
                list(MODEL_SPECS),
                value="GPT-2 Original",
                label="Select Model",
            ),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
    )
    return demo


if __name__ == "__main__":
    chat_interface = create_chat_interface()
    chat_interface.launch(share=True)