File size: 3,891 Bytes
4559b07
2e485c9
4559b07
2e485c9
 
 
 
 
 
 
 
 
 
 
 
 
 
4559b07
 
 
 
 
ec2ec98
4559b07
 
 
 
2e485c9
 
 
ec2ec98
2e485c9
 
4559b07
2e485c9
 
 
 
 
 
 
4559b07
2e485c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4559b07
ec2ec98
 
 
 
 
2e485c9
 
 
 
 
 
 
 
 
ec2ec98
 
 
 
 
 
 
 
 
 
 
 
4559b07
 
ec2ec98
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
from transformers import pipeline

# Memo for the expensive model instantiation below — populated on first call.
_PIPELINES_CACHE = None

def load_pipelines():
    """Build (once) and return the model-name -> transformers pipeline mapping.

    Returns:
        dict[str, Pipeline]: display name -> ready-to-call pipeline.

    Notes:
        * BUG FIX: T5 is an encoder-decoder model; loading it under the
          causal "text-generation" task raises at construction time, which
          previously crashed this whole function (and therefore every chat
          turn). T5 entries now use the correct "text2text-generation" task.
        * The result is cached at module level: this function is called from
          the per-message handler, and re-instantiating nine models per
          message would re-download / re-load gigabytes each turn.
    """
    global _PIPELINES_CACHE
    if _PIPELINES_CACHE is None:
        _PIPELINES_CACHE = {
            "GPT-2 Original": pipeline("text-generation", model="gpt2"),
            "GPT-2 Medium": pipeline("text-generation", model="gpt2-medium"),
            "DistilGPT-2": pipeline("text-generation", model="distilgpt2"),
            "German GPT-2": pipeline("text-generation", model="german-nlp-group/german-gpt2"),
            "German Wechsel GPT-2": pipeline("text-generation", model="benjamin/gpt2-wechsel-german"),
            "T5 Base": pipeline("text2text-generation", model="t5-base"),
            "T5 Large": pipeline("text2text-generation", model="t5-large"),
            "Text Classification": pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english"),
            "Sentiment Analysis": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment"),
        }
    return _PIPELINES_CACHE

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    model_name,
    max_tokens,
    temperature,
    top_p,
):
    """Produce one assistant reply for the Gradio ChatInterface.

    Args:
        message: Latest user message.
        history: Prior (user, assistant) turn pairs; assistant may be None.
        system_message: System prompt prepended to the generation prompt.
        model_name: Key into the dict returned by load_pipelines().
        max_tokens: Maximum number of NEW tokens to generate.
        temperature: Sampling temperature (effective only with sampling on).
        top_p: Nucleus-sampling threshold.

    Returns:
        str: the assistant reply, or a human-readable error string.
    """
    pipelines = load_pipelines()
    pipe = pipelines.get(model_name)

    if not pipe:
        return "Error: Model not found."

    generation_models = {
        "GPT-2 Original", "GPT-2 Medium", "DistilGPT-2",
        "German GPT-2", "German Wechsel GPT-2",
        "T5 Base", "T5 Large",
    }

    if model_name in generation_models:
        # Flatten prior turns into a single prompt string.
        full_history = ' '.join(f"{user} {reply or ''}" for user, reply in history) if history else ''
        full_prompt = f"{system_message}\n{full_history}\nUser: {message}\nAssistant:"

        try:
            generated = pipe(
                full_prompt,
                # BUG FIX: old code passed max_length=len(full_prompt)+max_tokens,
                # mixing a *character* count with a *token* budget; max_new_tokens
                # is the correct, prompt-length-independent parameter.
                max_new_tokens=max_tokens,
                # BUG FIX: without do_sample=True generation is greedy and the
                # temperature/top_p sliders silently do nothing.
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=1,
            )[0]['generated_text']

            # Decoder-only models echo the prompt; strip it only when actually
            # present so seq2seq outputs (which don't echo) aren't truncated.
            if generated.startswith(full_prompt):
                return generated[len(full_prompt):].strip()
            return generated.strip()
        except Exception as e:
            return f"Generation error: {e}"

    if model_name == "Text Classification":
        try:
            result = pipe(message)[0]
            return f"Classification: {result['label']} (Confidence: {result['score']:.2f})"
        except Exception as e:
            return f"Classification error: {e}"

    if model_name == "Sentiment Analysis":
        try:
            result = pipe(message)[0]
            return f"Sentiment: {result['label']} (Confidence: {result['score']:.2f})"
        except Exception as e:
            return f"Sentiment analysis error: {e}"

    # BUG FIX: previously fell through returning None, which Gradio would
    # render as an empty assistant bubble with no explanation.
    return "Error: Unsupported model selection."

def create_chat_interface():
    """Assemble the Gradio ChatInterface wired to respond().

    The extra controls map one-to-one onto respond()'s trailing parameters:
    system message, model choice, max new tokens, temperature, and top-p.
    """
    model_choices = [
        "GPT-2 Original", "GPT-2 Medium", "DistilGPT-2",
        "German GPT-2", "German Wechsel GPT-2",
        "T5 Base", "T5 Large",
        "Text Classification", "Sentiment Analysis",
    ]
    extra_controls = [
        gr.Textbox(value="You are a helpful assistant.", label="System message"),
        gr.Dropdown(model_choices, value="GPT-2 Original", label="Select Model"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ]
    return gr.ChatInterface(respond, additional_inputs=extra_controls)

if __name__ == "__main__":
    # NOTE: share=True publishes a public Gradio tunnel URL; launch() blocks
    # the process until the server is shut down.
    create_chat_interface().launch(share=True)