Nymbo committed · verified
Commit 98674ca · 1 Parent(s): 52ad57a

custom models

Files changed (1):
  1. app.py +64 -69
app.py CHANGED
@@ -21,7 +21,8 @@ def respond(
     temperature,
     top_p,
     frequency_penalty,
-    seed
+    seed,
+    custom_model
 ):
     """
     This function handles the chatbot response. It takes in:
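
Note on the signature: gr.ChatInterface passes the values from additional_inputs to the callback positionally, after message and history, so custom_model is added last here to line up with the new "Custom Model" textbox being listed last in the UI (see the final hunk). A minimal, self-contained sketch of that ordering contract (toy callback, not the app's code):

import gradio as gr

# Each component in additional_inputs maps, in order, to one extra
# parameter after (message, history).
def echo(message, history, system_message, custom_model):
    return f"[{custom_model or 'default model'}] {system_message} | {message}"

demo = gr.ChatInterface(
    fn=echo,
    additional_inputs=[
        gr.Textbox(label="System message"),  # -> system_message
        gr.Textbox(label="Custom Model"),    # -> custom_model (last, as in this commit)
    ],
)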
@@ -33,6 +34,7 @@ def respond(
     - top_p: top-p (nucleus) sampling
     - frequency_penalty: penalize repeated tokens in the output
     - seed: a fixed seed for reproducibility; -1 will mean 'random'
+    - custom_model: the user-provided custom model name (if any)
     """
 
     print(f"Received message: {message}")
@@ -40,6 +42,7 @@ def respond(
     print(f"System message: {system_message}")
     print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
     print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
+    print(f"Custom model: {custom_model}")
 
     # Convert seed to None if -1 (meaning random)
     if seed == -1:
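
For reference, the sentinel convention the seed handling relies on, isolated as a toy helper (normalize_seed is a hypothetical name, not part of this commit): a seed of None means the request is not pinned to a fixed seed, so -1 from the slider translates to "random".

from typing import Optional

def normalize_seed(seed: int) -> Optional[int]:
    # -1 is the UI sentinel for "random"; None asks the backend
    # not to fix a seed at all.
    return None if seed == -1 else seed

assert normalize_seed(-1) is None
assert normalize_seed(42) == 42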
@@ -62,26 +65,30 @@ def respond(
     # Append the latest user message
     messages.append({"role": "user", "content": message})
 
+    # Determine which model to use: either custom_model or a default
+    model_to_use = custom_model.strip() if custom_model.strip() != "" else "meta-llama/Llama-3.3-70B-Instruct"
+    print(f"Model selected for inference: {model_to_use}")
+
     # Start with an empty string to build the response as tokens stream in
     response = ""
     print("Sending request to OpenAI API.")
 
     # Make the streaming request to the HF Inference API via openai-like client
     for message_chunk in client.chat.completions.create(
-        model="meta-llama/Llama-3.3-70B-Instruct",  # You can update this to your specific model
+        model=model_to_use,  # Use either the user-provided custom model or default
         max_tokens=max_tokens,
-        stream=True,  # Stream the response
+        stream=True,  # Stream the response
         temperature=temperature,
         top_p=top_p,
-        frequency_penalty=frequency_penalty,  # <-- NEW
-        seed=seed,  # <-- NEW
+        frequency_penalty=frequency_penalty,
+        seed=seed,
         messages=messages,
     ):
         # Extract the token text from the response chunk
         token_text = message_chunk.choices[0].delta.content
         print(f"Received token: {token_text}")
         response += token_text
-        # As streaming progresses, yield partial output
+        # Yield the partial response to Gradio so it can display in real-time
        yield response
 
     print("Completed response generation.")
@@ -90,69 +97,57 @@ def respond(
 chatbot = gr.Chatbot(height=600)
 print("Chatbot interface created.")
 
-MODELS_LIST = [
-    "meta-llama/Llama-3.1-8B-Instruct",
-    "microsoft/Phi-3.5-mini-instruct",
-]
-
-def filter_models(search_term):
-    """
-    Simple function to filter the placeholder model list based on the user's input
-    """
-    filtered_models = [m for m in MODELS_LIST if search_term.lower() in m.lower()]
-    return gr.update(choices=filtered_models)
-
-# --------------------------------------
-# REBUILD THE INTERFACE USING BLOCKS
-# --------------------------------------
-print("Building Gradio interface with Blocks...")
-
-with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
-    # Title
-    gr.Markdown("# Serverless-TextGen-Hub")
-
-    # Accordion: Parameters (sliders, etc.)
-    with gr.Accordion("Parameters", open=True):
-        system_message = gr.Textbox(value="", label="System message")
-        max_tokens = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens")
-        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
-        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
-        frequency_penalty = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
-        seed = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
-
-    # Accordion: Featured Models (Below the parameters)
-    with gr.Accordion("Featured Models", open=False):
-        model_search = gr.Textbox(
-            label="Filter Models",
-            placeholder="Search for a featured model...",
-            lines=1
-        )
-        model_radio = gr.Radio(
-            label="Select a model below",
-            value=MODELS_LIST[0],  # default
-            choices=MODELS_LIST,
-            interactive=True
-        )
-        model_search.change(filter_models, inputs=model_search, outputs=model_radio)
-
-    # The main ChatInterface
-    chat_interface = gr.ChatInterface(
-        fn=respond,
-        additional_inputs=[
-            system_message,
-            max_tokens,
-            temperature,
-            top_p,
-            frequency_penalty,
-            seed
-        ],
-        fill_height=True,
-        chatbot=chatbot,
-        theme="Nymbo/Nymbo_Theme",
-        title="Serverless-TextGen-Hub",
-        description="A comprehensive UI for text generation using the HF Inference API."
-    )
-
+# Create the Gradio ChatInterface
+# Alongside the two new sliders for Frequency Penalty and Seed, this adds a "Custom Model" text box.
+demo = gr.ChatInterface(
+    fn=respond,
+    additional_inputs=[
+        gr.Textbox(value="", label="System message"),
+        gr.Slider(
+            minimum=1,
+            maximum=4096,
+            value=512,
+            step=1,
+            label="Max new tokens"
+        ),
+        gr.Slider(
+            minimum=0.1,
+            maximum=4.0,
+            value=0.7,
+            step=0.1,
+            label="Temperature"
+        ),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.95,
+            step=0.05,
+            label="Top-P"
+        ),
+        gr.Slider(
+            minimum=-2.0,
+            maximum=2.0,
+            value=0.0,
+            step=0.1,
+            label="Frequency Penalty"
+        ),
+        gr.Slider(
+            minimum=-1,
+            maximum=65535,  # Arbitrary upper limit for demonstration
+            value=-1,
+            step=1,
+            label="Seed (-1 for random)"
+        ),
+        gr.Textbox(
+            value="",
+            label="Custom Model",
+            info="(Optional) Provide a custom Hugging Face model path. This will override the default model if not empty."
+        ),
+    ],
+    fill_height=True,
+    chatbot=chatbot,
+    theme="Nymbo/Nymbo_Theme",
+)
 print("Gradio interface initialized.")
 
 if __name__ == "__main__":
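
The hunks assume a module-level client that speaks the OpenAI chat-completions API against Hugging Face's serverless inference endpoint; its construction is not part of this diff. A plausible setup, with the base URL and token handling given as assumptions rather than what app.py actually contains:

import os
from openai import OpenAI

# Assumed client setup (not shown in this commit): an OpenAI-compatible
# client pointed at the HF Inference API, authenticated with an HF token.
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1/",
    api_key=os.environ.get("HF_TOKEN"),
)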
 
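
Usage-wise, leaving the "Custom Model" box empty keeps the default meta-llama/Llama-3.3-70B-Instruct, while any non-blank repo path overrides it. The override rule from the diff, restated as a standalone helper (pick_model is a hypothetical name) to make the blank-handling explicit:

DEFAULT_MODEL = "meta-llama/Llama-3.3-70B-Instruct"

def pick_model(custom_model: str) -> str:
    # Any non-blank entry in the "Custom Model" box wins over the default;
    # whitespace-only input counts as empty.
    return custom_model.strip() or DEFAULT_MODEL

assert pick_model("") == DEFAULT_MODEL
assert pick_model("   ") == DEFAULT_MODEL
assert pick_model("mistralai/Mistral-7B-Instruct-v0.3") == "mistralai/Mistral-7B-Instruct-v0.3"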