Alibrown commited on
Commit
2e485c9
·
verified ·
1 Parent(s): ec2ec98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -50
app.py CHANGED
@@ -1,25 +1,20 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
4
- # Define available models
5
- MODELS = {
6
- "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
7
- "GPT-2": "gpt2",
8
- "GPT-2 Medium": "gpt2-medium",
9
- "DistilGPT-2": "distilgpt2",
10
- "German GPT-2": "german-nlp-group/german-gpt2",
11
- "German Wechsel GPT-2": "benjamin/gpt2-wechsel-german",
12
- "T5 Base": "t5-base",
13
- "T5 Large": "t5-large"
14
- }
15
-
16
- def create_inference_client(model_name):
17
- """Create an InferenceClient for the selected model."""
18
- try:
19
- return InferenceClient(model_name)
20
- except Exception as e:
21
- print(f"Error creating client for {model_name}: {e}")
22
- return None
23
 
24
  def respond(
25
  message,
@@ -30,45 +25,65 @@ def respond(
30
  temperature,
31
  top_p,
32
  ):
33
- """Generate response using selected model."""
34
- # Create client for selected model
35
- client = create_inference_client(MODELS[model_name])
36
 
37
- if not client:
38
- return "Error: Could not create model client."
39
 
40
- # Prepare chat history
41
- messages = [{"role": "system", "content": system_message}]
42
- for val in history:
43
- if val[0]:
44
- messages.append({"role": "user", "content": val[0]})
45
- if val[1]:
46
- messages.append({"role": "assistant", "content": val[1]})
47
- messages.append({"role": "user", "content": message})
48
 
49
- # Generate response
50
- try:
51
- response = ""
52
- for message in client.chat_completion(
53
- messages,
54
- max_tokens=max_tokens,
55
- stream=True,
56
- temperature=temperature,
57
- top_p=top_p,
58
- ):
59
- token = message.choices[0].delta.content or ""
60
- response += token
61
- yield response
62
- except Exception as e:
63
- yield f"Error during generation: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  def create_chat_interface():
66
  """Create Gradio ChatInterface with model selection."""
67
  demo = gr.ChatInterface(
68
  respond,
69
  additional_inputs=[
70
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
71
- gr.Dropdown(list(MODELS.keys()), value="Zephyr 7B Beta", label="Select Model"),
 
 
 
 
 
 
 
72
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
73
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
74
  gr.Slider(
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
 
4
# Define all pipelines (built lazily, cached for the life of the process)
def load_pipelines():
    """Return the mapping of display name -> Hugging Face pipeline.

    The dict is built on first call and cached on the function object:
    the original implementation reconstructed all nine pipelines on every
    invocation, and `respond` calls this once per chat message, which would
    re-load every model for each message.

    Returns:
        dict[str, transformers.Pipeline]: model display names used by the
        UI dropdown mapped to ready-to-call pipelines.
    """
    if not hasattr(load_pipelines, "_cache"):
        load_pipelines._cache = {
            "GPT-2 Original": pipeline("text-generation", model="gpt2"),
            "GPT-2 Medium": pipeline("text-generation", model="gpt2-medium"),
            "DistilGPT-2": pipeline("text-generation", model="distilgpt2"),
            "German GPT-2": pipeline("text-generation", model="german-nlp-group/german-gpt2"),
            "German Wechsel GPT-2": pipeline("text-generation", model="benjamin/gpt2-wechsel-german"),
            # BUG FIX: T5 checkpoints are encoder-decoder models; loading them
            # under the "text-generation" (causal LM) task raises at pipeline
            # construction time. "text2text-generation" is the matching task.
            "T5 Base": pipeline("text2text-generation", model="t5-base"),
            "T5 Large": pipeline("text2text-generation", model="t5-large"),
            "Text Classification": pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english"),
            "Sentiment Analysis": pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment"),
        }
    return load_pipelines._cache
 
 
 
 
 
18
 
19
def respond(
    message,
    history,
    system_message,
    model_name,
    max_tokens,
    temperature,
    top_p,
):
    """Route one chat turn to the selected pipeline and return its reply.

    Parameters follow the Gradio ChatInterface contract: `message` is the
    current user input, `history` a list of (user, assistant) pairs, and the
    remaining arguments arrive from the additional_inputs widgets in order
    (system message, model dropdown, max tokens, temperature, top_p).

    Returns:
        str: the model's reply, or a human-readable error description.
    """
    pipelines = load_pipelines()
    pipe = pipelines.get(model_name)

    if not pipe:
        return "Error: Model not found."

    generation_models = {
        "GPT-2 Original", "GPT-2 Medium", "DistilGPT-2",
        "German GPT-2", "German Wechsel GPT-2",
        "T5 Base", "T5 Large",
    }

    # For text generation models
    if model_name in generation_models:
        # Flatten prior turns into a single prompt string.
        if history:
            full_history = ' '.join(
                f"{user_turn} {assistant_turn or ''}" for user_turn, assistant_turn in history
            )
        else:
            full_history = ''
        full_prompt = f"{system_message}\n{full_history}\nUser: {message}\nAssistant:"

        try:
            # BUG FIX: the original passed max_length=len(full_prompt)+max_tokens,
            # mixing a *character* count with a *token* budget; max_new_tokens is
            # the correct, token-denominated knob for "new text only".
            # do_sample=True is required for temperature/top_p to take effect at
            # all (generation defaults to greedy decoding, silently ignoring both).
            response = pipe(
                full_prompt,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=1,
            )[0]['generated_text']

            # Causal LMs echo the prompt; strip it only when it is actually
            # echoed so non-echoing (seq2seq) outputs are not corrupted.
            if response.startswith(full_prompt):
                response = response[len(full_prompt):]
            return response.strip()
        except Exception as e:
            return f"Generation error: {e}"

    # For classification and sentiment models
    elif model_name == "Text Classification":
        try:
            result = pipe(message)[0]
            return f"Classification: {result['label']} (Confidence: {result['score']:.2f})"
        except Exception as e:
            return f"Classification error: {e}"

    elif model_name == "Sentiment Analysis":
        try:
            result = pipe(message)[0]
            return f"Sentiment: {result['label']} (Confidence: {result['score']:.2f})"
        except Exception as e:
            return f"Sentiment analysis error: {e}"

    # ROBUSTNESS: the original fell through and implicitly returned None,
    # which Gradio renders as an empty bubble; report the problem instead.
    return "Error: Unsupported model selection."
72
 
73
  def create_chat_interface():
74
  """Create Gradio ChatInterface with model selection."""
75
  demo = gr.ChatInterface(
76
  respond,
77
  additional_inputs=[
78
+ gr.Textbox(value="You are a helpful assistant.", label="System message"),
79
+ gr.Dropdown(
80
+ ["GPT-2 Original", "GPT-2 Medium", "DistilGPT-2",
81
+ "German GPT-2", "German Wechsel GPT-2",
82
+ "T5 Base", "T5 Large",
83
+ "Text Classification", "Sentiment Analysis"],
84
+ value="GPT-2 Original",
85
+ label="Select Model"
86
+ ),
87
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
88
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
89
  gr.Slider(