The model was trained for 11 epochs, reaching a loss of about 1.068 on a filtered version of the dataset (39.4k samples). For best results, use the following code for inference:

```python
import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

# Load the fine-tuned model and tokenizer from the given path or name
model_name = "Pankaj8922/Instr-GPT2-MM"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Ensure the pad token is set to the EOS token (if not already set)
tokenizer.pad_token = tokenizer.eos_token

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def generate_response_live(history, max_length=1024, max_new_tokens=512, temperature=1.0, top_k=70, top_p=0.998):
    inputs = tokenizer(history, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    input_ids = inputs.input_ids.to(device)  # Move input to GPU if available

    generated_text = ""
    model.eval()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Full forward pass over the current prefix (simple, but slower
            # than caching past_key_values)
            outputs = model(input_ids=input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_logits = next_token_logits / temperature

            # Top-k filtering: mask out everything below the k-th highest logit
            filtered_logits = torch.topk(next_token_logits, top_k)[0]
            next_token_logits = torch.where(
                next_token_logits < filtered_logits[:, [-1]],
                torch.full_like(next_token_logits, float('-inf')),
                next_token_logits
            )

            # Top-p (nucleus) filtering: drop tokens outside the smallest set
            # whose cumulative probability exceeds top_p
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(
                torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
            )
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False
            indices_to_remove = sorted_indices_to_remove.scatter(
                1, sorted_indices, sorted_indices_to_remove
            )
            next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))

            # Sample the next token from the filtered distribution
            next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(next_token_probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

            next_token = tokenizer.decode(next_token_id[0])
            generated_text += next_token

            # Print the next token to show real-time generation
            print(next_token, end='', flush=True)

            # Stop once the "<end>" marker or the model's EOS token appears
            if "<end>" in generated_text or next_token_id.item() == tokenizer.eos_token_id:
                break

    print()  # Print a newline after the generation is complete
    return generated_text.split("<end>")[0].strip()

def chat(user_input, history, single_input_mode):
    if not user_input.strip():
        return "", history  # Return without updating if input is empty
    
    if single_input_mode:
        prompt = f"USER: {user_input} <end>"
        ai_response = generate_response_live(prompt)
        display_text = f"USER: {user_input}\n{ai_response}"
        return display_text, ""
    else:
        prompt = f"USER: {user_input} <end>"
        history += prompt
        ai_response = generate_response_live(history)
        # Re-append the <end> marker (stripped by generate_response_live) so
        # the stored history keeps the training format
        history += f"{ai_response} <end>"
        
        # Format the display text
        display_text = history.replace("<end>", "").replace("USER:", "\nUSER: ").replace("AI:", "\nAI: ")
        return display_text, history

def reset_history():
    return "", ""

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=6):
            chat_box = gr.Textbox(label="Conversation", placeholder="Chat will appear here...", interactive=False, lines=20)
        with gr.Column(scale=4):
            user_input = gr.Textbox(label="Your Input", placeholder="Type your message here...")
            send_button = gr.Button("Send")
            clear_button = gr.Button("Clear Chat")
            single_input_checkbox = gr.Checkbox(label="Single Input Mode")
            loading_spinner = gr.HTML("<div style='text-align:center;'><img src='https://i.gifer.com/ZZ5H.gif' width='50px' style='display:none;' id='loading-spinner'></div>")
    
    history = gr.State("")

    def show_loading():
        return gr.update(value="<div style='text-align:center;'><img src='https://i.gifer.com/ZZ5H.gif' width='50px' style='display:block;' id='loading-spinner'></div>")

    def hide_loading():
        return gr.update(value="<div style='text-align:center;'><img src='https://i.gifer.com/ZZ5H.gif' width='50px' style='display:none;' id='loading-spinner'></div>")
    
    # Chain the events so the spinner shows while generation runs and the
    # input box is cleared only after the response is ready
    send_button.click(show_loading, outputs=loading_spinner).then(
        chat, inputs=[user_input, history, single_input_checkbox], outputs=[chat_box, history]
    ).then(hide_loading, outputs=loading_spinner).then(
        lambda: "", outputs=user_input  # Clear the input box after sending
    )

    clear_button.click(reset_history, outputs=[chat_box, history])

demo.launch()
```
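
For a one-off completion without the Gradio UI, the built-in `generate` API from transformers gives equivalent top-k/top-p sampling without the manual token loop. Below is a minimal sketch, assuming the same `USER: ... <end>` prompt convention as the demo above; the example question is arbitrary:

```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

model_name = "Pankaj8922/Instr-GPT2-MM"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Same prompt convention as the demo above (example question is arbitrary)
prompt = "USER: What is the capital of France? <end>"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=1.0,
        top_k=70,
        top_p=0.998,
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode only the newly generated tokens and cut at the "<end>" marker
completion = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:])
print(completion.split("<end>")[0].strip())
```

Unlike the loop above, `generate` does not stream tokens as they are produced, but it handles caching and stopping internally; the output is cut at the first `<end>` just as the demo does.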