Adityak204 committed
Commit 70a0a5b · 1 Parent(s): ed6a6f5

Initial commit

app.py ADDED
@@ -0,0 +1,125 @@
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import os
from dataclasses import dataclass
from huggingface_hub import hf_hub_download

from src.model import SmolLM


def greedy_decode(model, input_ids, max_length=100, tokenizer=None):
    current_ids = input_ids

    with torch.no_grad():
        for _ in range(max_length - current_ids.shape[1]):
            outputs = model(current_ids)
            last_token_logits = outputs[:, -1, :]
            next_token = torch.argmax(last_token_logits, dim=-1).unsqueeze(-1)  # (B, 1)

            current_ids = torch.cat([current_ids, next_token], dim=1)

            if next_token.item() == tokenizer.eos_token_id:
                break

    return current_ids


def generate_prediction(model, prompt, max_length=100):
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    tokenizer.pad_token = tokenizer.eos_token
    device = next(model.parameters()).device

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    model.eval()
    with torch.no_grad():
        generated_ids = greedy_decode(
            model, input_ids, max_length=max_length, tokenizer=tokenizer
        )

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text


def main():
    # Set page configuration
    st.set_page_config(page_title="SmolLM2-TextGen", page_icon="🤖")

    # Title and description
    st.title("SmolLM2-TextGen 🤖")
    st.write("Generate text using the SmolLM2 language model")

    # Build the model (cached so it is only constructed once per session)
    @st.cache_resource
    def load_model(config):
        model = SmolLM(config)
        return model

    # Try to load the model
    try:

        @dataclass
        class MainConfig:
            vocab_size: int = 49152
            emb_dim: int = 576
            intermediate_size: int = 1536
            num_layers: int = 30
            n_q_heads: int = 9
            n_kv_heads: int = 3
            max_seq_len: int = 1024
            dropout: float = 0.1
            rms_norm_eps: float = 1e-05
            init_std: float = 0.041666666666666664

        config = MainConfig()
        model = load_model(config)
        # Load checkpoint weights from the Hugging Face Hub
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_repo = "Adityak204/SmolLM2-135-cosmopedia-10k"
        model_filename = "smolLM-v2.pth"
        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        checkpoint = torch.load(checkpoint_path, map_location=device)[
            "model_state_dict"
        ]
        model.load_state_dict(checkpoint)

    except Exception as e:
        st.error(f"Error loading model: {e}")
        return

    # Input prompt
    prompt = st.text_input(
        "Enter your prompt:", placeholder="Type a sentence to generate text..."
    )

    # Max length slider
    max_length = st.slider(
        "Maximum Generation Length", min_value=10, max_value=200, value=100, step=10
    )

    # Generate button
    if st.button("Generate Text"):
        if not prompt:
            st.warning("Please enter a prompt.")
            return

        # Show loading spinner
        with st.spinner("Generating text..."):
            try:
                # Generate text
                generated_text = generate_prediction(model, prompt, max_length)

                # Display generated text
                st.subheader("Generated Text:")
                st.write(generated_text)

            except Exception as e:
                st.error(f"An error occurred during text generation: {e}")


if __name__ == "__main__":
    main()
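
The app is launched with "streamlit run app.py". As a reference, here is a minimal headless sketch (not part of the commit) showing how the same generation path could be exercised from a plain Python script, assuming the project dependencies (torch, transformers, huggingface_hub, streamlit) are installed and the repository root is the working directory; the config values simply mirror MainConfig above.

# Headless usage sketch (illustrative; mirrors the values in app.py's MainConfig)
import torch
from dataclasses import dataclass
from huggingface_hub import hf_hub_download

from src.model import SmolLM
from app import generate_prediction


@dataclass
class Config:
    vocab_size: int = 49152
    emb_dim: int = 576
    intermediate_size: int = 1536
    num_layers: int = 30
    n_q_heads: int = 9
    n_kv_heads: int = 3
    max_seq_len: int = 1024
    dropout: float = 0.1
    rms_norm_eps: float = 1e-05
    init_std: float = 0.041666666666666664


model = SmolLM(Config())
ckpt = hf_hub_download("Adityak204/SmolLM2-135-cosmopedia-10k", "smolLM-v2.pth")
model.load_state_dict(torch.load(ckpt, map_location="cpu")["model_state_dict"])
print(generate_prediction(model, "Once upon a time", max_length=60))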
src/__init__.py ADDED
(empty file)
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (173 Bytes)
src/__pycache__/model.cpython-310.pyc ADDED
Binary file (5.78 kB)
src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.73 kB)
src/model.py ADDED
@@ -0,0 +1,201 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from src.utils import LlamaRotaryEmbedding, repeat_kv


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Root Mean Square Layer Normalization
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class Attention(nn.Module):
    """Multi-head attention module with support for GQA (Grouped Query Attention)."""

    def __init__(self, config):
        super(Attention, self).__init__()
        self.emb_dim = config.emb_dim
        self.n_q_heads = config.n_q_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = self.emb_dim // self.n_q_heads
        self.n_rep = self.n_q_heads // self.n_kv_heads

        # Projections for Q, K, V & O
        self.q_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
        self.k_proj = nn.Linear(
            self.emb_dim, self.head_dim * self.n_kv_heads, bias=False
        )
        self.v_proj = nn.Linear(
            self.emb_dim, self.head_dim * self.n_kv_heads, bias=False
        )
        self.o_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)

        # Initialize rotary embeddings
        self.rotary_embedding = LlamaRotaryEmbedding(
            dim=self.head_dim, max_seq_len=config.max_seq_len
        )

        # Dropout layers
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Causal mask
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len)).view(
                1, 1, config.max_seq_len, config.max_seq_len
            ),
        )

    def forward(self, x):
        B, T, C = x.size()  # batch_size, seq_len, emb_dim

        # Project Q, K, V
        q = self.q_proj(x)  # (B, T, emb_dim)
        k = self.k_proj(x)  # (B, T, n_kv_heads * head_dim)
        v = self.v_proj(x)  # (B, T, n_kv_heads * head_dim)

        # Reshape Q, K, V
        q = q.view(B, T, self.n_q_heads, self.head_dim)  # (B, T, n_q_heads, head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)  # (B, T, n_kv_heads, head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)  # (B, T, n_kv_heads, head_dim)

        # Reshape for attention computation
        q = q.transpose(1, 2)  # (B, n_q_heads, T, head_dim)
        k = k.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)
        v = v.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)

        # Apply rotary embeddings
        q, k = self.rotary_embedding(q, k)

        # Repeat K and V for GQA
        k = repeat_kv(k, self.n_rep)  # (B, n_q_heads, T, head_dim)
        v = repeat_kv(v, self.n_rep)  # (B, n_q_heads, T, head_dim)

        # Compute attention scores
        scale = 1.0 / math.sqrt(self.head_dim)
        att = (q @ k.transpose(-2, -1)) * scale  # (B, n_q_heads, T, T)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        # Apply attention to values
        y = att @ v  # (B, n_q_heads, T, head_dim)

        # Reshape and project output
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, emb_dim)
        y = self.o_proj(y)
        y = self.resid_dropout(y)

        return y


class FeedForward(nn.Module):
    """Feed-forward module with SiLU activation."""

    def __init__(self, config):
        super(FeedForward, self).__init__()
        # Gate and up-projections project from hidden_size to intermediate_size
        self.gate_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)

        # Down projection brings the dimension back to hidden_size
        self.down_proj = nn.Linear(config.intermediate_size, config.emb_dim, bias=False)

        # SiLU activation function
        self.act_fn = F.silu

        # Dropout layer
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Apply gate and up projections
        gate_output = self.act_fn(self.gate_proj(x))  # SiLU activation
        up_output = self.up_proj(x)

        # Element-wise multiplication of gate and up projections
        intermediate_output = gate_output * up_output

        # Project back to hidden size
        output = self.down_proj(intermediate_output)
        output = self.dropout(output)

        return output


class TransformerBlock(nn.Module):
    """Transformer block with attention and feed-forward modules."""

    def __init__(self, config):
        super(TransformerBlock, self).__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.input_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        self.attention_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)

    def forward(self, x):
        x = x + self.attention(self.input_layernorm(x))
        x = x + self.feed_forward(self.attention_layernorm(x))

        return x


class SmolLM(nn.Module):
    """Small language model with transformer blocks."""

    def __init__(self, config):
        super(SmolLM, self).__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.emb_dim)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_layers)]
        )

        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        self.apply(self._init_weights)
        self.layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)

        # weight sharing between lm_head and token embedding
        self.lm_head.weight = self.wte.weight

    def total_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x):
        x = self.wte(x)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.layernorm(x)
        logits = self.lm_head(x)
        return logits


# @dataclass
# class Config:
#     vocab_size: int = 49152
#     emb_dim: int = 576
#     intermediate_size: int = 1536
#     num_layers: int = 10
#     n_q_heads: int = 9
#     n_kv_heads: int = 3
#     max_seq_len: int = 8192
#     dropout: float = 0.1
#     rms_norm_eps: float = 1e-05
#     init_std: float = 0.041666666666666664
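
For reference, a quick shape check for SmolLM (an illustrative sketch, not part of the commit), using a deliberately tiny, made-up config rather than the 135M configuration from app.py: logits should come out as (batch, seq_len, vocab_size), and lm_head shares its weight with the token embedding.

# Tiny smoke test for SmolLM (hypothetical config, chosen only so the shapes line up)
import torch
from dataclasses import dataclass
from src.model import SmolLM


@dataclass
class TinyConfig:
    vocab_size: int = 100
    emb_dim: int = 48            # divisible by n_q_heads -> head_dim = 8
    intermediate_size: int = 96
    num_layers: int = 2
    n_q_heads: int = 6           # multiple of n_kv_heads (GQA)
    n_kv_heads: int = 2
    max_seq_len: int = 64
    dropout: float = 0.0
    rms_norm_eps: float = 1e-05
    init_std: float = 0.02


model = SmolLM(TinyConfig())
tokens = torch.randint(0, 100, (2, 16))            # (batch, seq_len)
logits = model(tokens)
assert logits.shape == (2, 16, 100)                # (batch, seq_len, vocab_size)
assert model.lm_head.weight.data_ptr() == model.wte.weight.data_ptr()  # tied weights
print(f"trainable params: {model.total_params():,}")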
src/utils.py ADDED
@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class LlamaRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim: int = 64,  # Dimension per attention head
        max_seq_len: int = 2048,  # Maximum sequence length
        base: int = 10000,  # Base for the angle calculations
        device: str = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Create cache for position frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Create position sequence
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_len: int):
        # Return early if cache is valid
        if seq_len <= self._seq_len_cached:
            return

        # Update cache size
        self._seq_len_cached = seq_len

        # Create position sequence
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        # Calculate position frequencies
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Calculate embeddings
        emb = torch.cat((freqs, freqs), dim=-1)
        self._cos_cached = emb.cos()  # [seq_len, dim]
        self._sin_cached = emb.sin()  # [seq_len, dim]

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch, num_heads, seq_len, head_dim = q.shape

        # Update cos/sin tables if needed
        self._update_cos_sin_tables(q, seq_len)

        # Get cos and sin for current sequence
        cos = (
            self._cos_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)
        )  # Shape: [1, 1, seq_len, dim]
        sin = (
            self._sin_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)
        )  # Shape: [1, 1, seq_len, dim]

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        # Apply rotary embeddings to q and k
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)

        return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
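
A small sanity check for these helpers (illustrative, not part of the commit): the rotary embedding leaves shapes and per-position norms unchanged, and repeat_kv tiles each of the 3 KV heads 3 times so they line up with the 9 query heads, matching the GQA setup in src/model.py.

# Sanity checks for LlamaRotaryEmbedding and repeat_kv (illustrative values)
import torch
from src.utils import LlamaRotaryEmbedding, repeat_kv

B, n_q, n_kv, T, d = 2, 9, 3, 10, 64
q = torch.randn(B, n_q, T, d)
k = torch.randn(B, n_kv, T, d)

rope = LlamaRotaryEmbedding(dim=d, max_seq_len=128)
q_rot, k_rot = rope(q, k)
assert q_rot.shape == q.shape and k_rot.shape == k.shape              # shapes preserved
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)  # rotation keeps norms

k_rep = repeat_kv(k, n_q // n_kv)              # 3 KV heads tiled 3x -> 9 heads
assert k_rep.shape == (B, n_q, T, d)
assert torch.equal(k_rep[:, 0], k_rep[:, 1])   # adjacent output heads are copies of KV head 0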