The model was trained for 11 epochs on a filtered version of the dataset (39.4k samples), reaching a loss of roughly 1.068. Use the following code for the best results and for any inference:
import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
# Load the fine-tuned model and tokenizer from the given Hub name or local path.
model_name = "Pankaj8922/Instr-GPT2-MM"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Reuse the EOS token for padding only when no pad token is configured,
# instead of unconditionally overwriting one the checkpoint may define.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Move the model to the GPU if available, otherwise keep it on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only: disable dropout once at load time
def _sample_next_token(logits, temperature, top_k, top_p):
    """Draw one token id per row of ``logits`` with temperature, top-k and top-p filtering.

    ``logits`` is the last-position slice of the model output, indexed as
    ``[batch, vocab]``; the returned tensor has shape ``[batch, 1]``.
    """
    if temperature <= 0:
        # Treat temperature 0 as greedy decoding; dividing by it would
        # produce inf/nan logits and make torch.multinomial raise.
        return torch.argmax(logits, dim=-1, keepdim=True)
    logits = logits / temperature
    if top_k > 0:
        # torch.topk raises if k exceeds the vocabulary size.
        k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, k).values[:, -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        # Drop a token once the probability mass ranked above it already
        # exceeds top_p, so the single most likely token is always kept.
        sorted_remove = (sorted_probs.cumsum(dim=-1) - sorted_probs) > top_p
        remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def generate_response_live(history, max_length=1024, max_new_tokens=512, temperature=1.0, top_k=70, top_p=0.998):
    """Sample a reply to ``history`` token by token, echoing it to stdout as it grows.

    Args:
        history: Prompt text (a single turn or the whole conversation transcript).
        max_length: Maximum number of prompt tokens kept. The MOST RECENT tokens
            are kept, and the count is also capped at the model's context size.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Softmax temperature; values <= 0 select greedy decoding.
        top_k: Keep only the ``top_k`` most likely tokens (<= 0 disables).
        top_p: Nucleus-sampling probability mass (>= 1.0 disables).

    Returns:
        The generated text up to (not including) the first ``"<end>"`` marker,
        stripped of surrounding whitespace. Generation also stops on EOS.
    """
    # GPT-2 uses learned absolute position embeddings, so the model can never
    # attend over more than n_positions tokens at once.
    n_ctx = model.config.n_positions
    input_ids = tokenizer(history, return_tensors="pt").input_ids
    # Truncate from the LEFT: in chat mode the newest user turn sits at the end
    # of `history`, and right-truncation would silently drop it.
    keep = max(1, min(max_length, n_ctx))
    input_ids = input_ids[:, -keep:].to(device)
    if input_ids.shape[1] == 0:
        return ""  # nothing to condition on

    generated_ids = []
    generated_text = ""
    shown = ""  # prefix of generated_text already printed to stdout
    past = None
    next_token_id = None
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if past is not None and input_ids.shape[1] <= n_ctx:
                # Reuse the cached keys/values and feed only the newest token.
                outputs = model(input_ids=next_token_id, past_key_values=past, use_cache=True)
            else:
                # First step, or the sequence outgrew the context window:
                # recompute over the most recent n_ctx tokens.
                outputs = model(input_ids=input_ids[:, -n_ctx:], use_cache=True)
            past = outputs.past_key_values
            next_token_id = _sample_next_token(outputs.logits[:, -1, :], temperature, top_k, top_p)
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

            token = next_token_id.item()
            if token == tokenizer.eos_token_id:
                break
            generated_ids.append(token)
            # Decode the whole reply every step: byte-level BPE can split one
            # character across several tokens, so per-token decoding garbles it.
            generated_text = tokenizer.decode(generated_ids)
            # Hold back output while a multi-byte character is still incomplete.
            if not generated_text.endswith("\ufffd"):
                print(generated_text[len(shown):], end="", flush=True)
                shown = generated_text
            if "<end>" in generated_text:
                break
    print()  # newline after the streamed reply
    return generated_text.split("<end>")[0].strip()
def _render_history(history):
    """Turn the raw transcript into the text shown in the Conversation box."""
    return (
        history.replace("<end>", "")
        .replace("USER:", "\nUSER: ")
        .replace("AI:", "\nAI: ")
    )


def chat(user_input, history, single_input_mode):
    """Handle one "Send" click from the Gradio UI.

    Args:
        user_input: Text typed by the user.
        history: Raw transcript kept in ``gr.State``; each turn is terminated
            by an ``"<end>"`` marker.
        single_input_mode: When true, answer ``user_input`` on its own,
            without any prior conversation as context.

    Returns:
        ``(display_text, new_history)`` for the chat box and the state.
        Single-input mode returns an empty history, which resets the stored
        conversation.
    """
    if not user_input.strip():
        # Nothing to send: keep the state and re-show the current conversation
        # rather than blanking the chat box.
        return _render_history(history), history

    prompt = f"USER: {user_input} <end>"

    if single_input_mode:
        ai_response = generate_response_live(prompt)
        return f"USER: {user_input}\n {ai_response}", ""

    history += prompt
    ai_response = generate_response_live(history)
    # generate_response_live strips the "<end>" marker that terminated the
    # reply; restore it so the next "USER:" turn is not glued onto this one.
    history += f"{ai_response} <end>"
    return _render_history(history), history
def reset_history():
    """Clear the chat: blank the conversation display and empty the stored transcript."""
    cleared_display, cleared_state = "", ""
    return cleared_display, cleared_state
def _spinner_html(visible):
    """Return the loading-spinner markup, shown (``visible=True``) or hidden."""
    display = "block" if visible else "none"
    return (
        "<div style='text-align:center;'>"
        "<img src='https://i.gifer.com/ZZ5H.gif' width='50px' "
        f"style='display:{display};' id='loading-spinner'></div>"
    )


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=6):
            chat_box = gr.Textbox(label="Conversation", placeholder="Chat will appear here...", interactive=False, lines=20)
        with gr.Column(scale=4):
            user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...")
            send_button = gr.Button("Send")
            clear_button = gr.Button("Clear Chat")
            single_input_checkbox = gr.Checkbox(label="Single Input Mode")
            loading_spinner = gr.HTML(_spinner_html(visible=False))
    history = gr.State("")

    def show_loading():
        # gr.update works across Gradio 3 and 4; gr.HTML.update was removed in 4.
        return gr.update(value=_spinner_html(visible=True))

    def hide_loading():
        return gr.update(value=_spinner_html(visible=False))

    # Chain the steps so they run in order. Separate .click() listeners run
    # independently, so the spinner would be hidden before the reply arrives.
    # The input box is cleared only after `chat` has read it.
    (
        send_button.click(show_loading, outputs=loading_spinner)
        .then(chat, inputs=[user_input, history, single_input_checkbox], outputs=[chat_box, history])
        .then(hide_loading, outputs=loading_spinner)
        .then(fn=lambda: "", outputs=user_input)
    )
    clear_button.click(reset_history, outputs=[chat_box, history])

demo.launch()
- Downloads last month: 24
The serverless Inference API is not available for this model because the repository is disabled.