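"""SmolLM2-TextGen: a Streamlit app that generates text with a SmolLM2 model.

Loads trained weights from the Hugging Face Hub and decodes greedily from a
user-supplied prompt.
"""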
import streamlit as st
import torch
from transformers import AutoTokenizer
from dataclasses import dataclass
from huggingface_hub import hf_hub_download

from src.model import SmolLM
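

# Cached wrapper around the tokenizer setup that originally lived inside
# generate_prediction: st.cache_resource loads it once per Streamlit session
# instead of on every generation request.
@st.cache_resource
def load_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    # SmolLM2 has no dedicated pad token, so the EOS token is reused.
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer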


def greedy_decode(model, input_ids, max_length=100, tokenizer=None):
    """Greedily decode until `max_length` total tokens (prompt included) or EOS."""
    current_ids = input_ids

    with torch.no_grad():
        # max_length counts the prompt tokens, so only generate the remainder.
        for _ in range(max_length - current_ids.shape[1]):
            outputs = model(current_ids)
            # Only the logits at the last position matter for the next token.
            last_token_logits = outputs[:, -1, :]
            # keepdim=True yields shape (batch, 1), so the concatenation below
            # works for any batch size, not just batch size 1.
            next_token = torch.argmax(last_token_logits, dim=-1, keepdim=True)

            current_ids = torch.cat([current_ids, next_token], dim=1)

            # Stop early once the model emits the end-of-sequence token.
            if tokenizer is not None and next_token.item() == tokenizer.eos_token_id:
                break

    return current_ids
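

# Optional alternative decoder, included as an illustration only (the UI below
# uses greedy_decode): temperature sampling draws the next token from the
# softmax distribution instead of always taking the argmax. A minimal sketch
# assuming the same model interface as greedy_decode above.
def sample_decode(model, input_ids, max_length=100, tokenizer=None, temperature=0.8):
    current_ids = input_ids
    with torch.no_grad():
        for _ in range(max_length - current_ids.shape[1]):
            # Scale logits by the temperature; lower values sharpen the
            # distribution toward greedy behavior.
            logits = model(current_ids)[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            # Draw one token per sequence from the scaled distribution.
            next_token = torch.multinomial(probs, num_samples=1)
            current_ids = torch.cat([current_ids, next_token], dim=1)
            if tokenizer is not None and next_token.item() == tokenizer.eos_token_id:
                break
    return current_ids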


def generate_prediction(model, prompt, max_length=100):
    tokenizer = load_tokenizer()
    device = next(model.parameters()).device

    # Tokenize the prompt and move it to the same device as the model.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    model.eval()
    with torch.no_grad():
        generated_ids = greedy_decode(
            model, input_ids, max_length=max_length, tokenizer=tokenizer
        )

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text


def main():
    # Set page configuration
    st.set_page_config(page_title="SmolLM2-TextGen", page_icon="🤖")

    # Title and description
    st.title("SmolLM2-TextGen 🤖")
    st.write("Generate text using the SmolLM2 language model")

    # Build the model once and cache it across Streamlit reruns.
    @st.cache_resource
    def load_model(config):
        model = SmolLM(config)
        return model

    # Try to load the model
    try:

        @dataclass
        class MainConfig:
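            # Architecture hyperparameters matching SmolLM2-135M; the 9 query
            # heads share 3 key/value heads (grouped-query attention).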
            vocab_size: int = 49152
            emb_dim: int = 576
            intermediate_size: int = 1536
            num_layers: int = 30
            n_q_heads: int = 9
            n_kv_heads: int = 3
            max_seq_len: int = 1024
            dropout: float = 0.1
            rms_norm_eps: float = 1e-05
            init_std: float = 0.041666666666666664

        config = MainConfig()
        model = load_model(config)
        # load checkpoint
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # checkpoint_path = "/Users/aditya/Documents/self_learning/ERA V3/week 13/artifacts/m1/smolLM-v2.pth"
        model_repo = "Adityak204/SmolLM2-135-cosmopedia-10k"
        model_filename = "smolLM-v2.pth"
        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        checkpoint = torch.load(checkpoint_path, map_location=device)[
            "model_state_dict"
        ]
        model.load_state_dict(checkpoint)

    except Exception as e:
        st.error(f"Error loading model: {e}")
        return

    # Input prompt
    prompt = st.text_input(
        "Enter your prompt:", placeholder="Type a sentence to generate text..."
    )

    # Max length slider
    max_length = st.slider(
        "Maximum Generation Length", min_value=10, max_value=200, value=100, step=10
    )

    # Generate button
    if st.button("Generate Text"):
        if not prompt:
            st.warning("Please enter a prompt.")
            return

        # Show loading spinner
        with st.spinner("Generating text..."):
            try:
                # Generate text
                generated_text = generate_prediction(model, prompt, max_length)

                # Display generated text
                st.subheader("Generated Text:")
                st.write(generated_text)

            except Exception as e:
                st.error(f"An error occurred during text generation: {e}")


if __name__ == "__main__":
    main()