Commit 70a0a5b
Parent(s): ed6a6f5
Initial commit
Files changed:
- app.py +125 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/model.cpython-310.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/model.py +201 -0
- src/utils.py +85 -0
app.py
ADDED
@@ -0,0 +1,125 @@
import streamlit as st
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import os
from dataclasses import dataclass
from huggingface_hub import hf_hub_download

from src.model import SmolLM


def greedy_decode(model, input_ids, max_length=100, tokenizer=None):
    current_ids = input_ids

    with torch.no_grad():
        for _ in range(max_length - current_ids.shape[1]):
            outputs = model(current_ids)
            last_token_logits = outputs[:, -1, :]
            next_token = torch.argmax(last_token_logits, dim=-1).unsqueeze(0)

            current_ids = torch.cat([current_ids, next_token], dim=1)

            if next_token.item() == tokenizer.eos_token_id:
                break

    return current_ids


def generate_prediction(model, prompt, max_length=100):
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    tokenizer.pad_token = tokenizer.eos_token
    device = next(model.parameters()).device

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    model.eval()
    with torch.no_grad():
        generated_ids = greedy_decode(
            model, input_ids, max_length=max_length, tokenizer=tokenizer
        )

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return generated_text


def main():
    # Set page configuration
    st.set_page_config(page_title="SmolLM2-TextGen", page_icon="🤖")

    # Title and description
    st.title("SmolLM2-TextGen 🤖")
    st.write("Generate text using the SmolLM2 language model")

    # Build the model (cached by Streamlit so it is only constructed once)
    @st.cache_resource
    def load_model(config):
        model = SmolLM(config)
        return model

    # Try to load the model
    try:

        @dataclass
        class MainConfig:
            vocab_size: int = 49152
            emb_dim: int = 576
            intermediate_size: int = 1536
            num_layers: int = 30
            n_q_heads: int = 9
            n_kv_heads: int = 3
            max_seq_len: int = 1024
            dropout: float = 0.1
            rms_norm_eps: float = 1e-05
            init_std: float = 0.041666666666666664

        config = MainConfig()
        model = load_model(config)
        # load checkpoint
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # checkpoint_path = "/Users/aditya/Documents/self_learning/ERA V3/week 13/artifacts/m1/smolLM-v2.pth"
        model_repo = "Adityak204/SmolLM2-135-cosmopedia-10k"
        model_filename = "smolLM-v2.pth"
        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        checkpoint = torch.load(checkpoint_path, map_location=device)[
            "model_state_dict"
        ]
        model.load_state_dict(checkpoint)

    except Exception as e:
        st.error(f"Error loading model: {e}")
        return

    # Input prompt
    prompt = st.text_input(
        "Enter your prompt:", placeholder="Type a sentence to generate text..."
    )

    # Max length slider
    max_length = st.slider(
        "Maximum Generation Length", min_value=10, max_value=200, value=100, step=10
    )

    # Generate button
    if st.button("Generate Text"):
        if not prompt:
            st.warning("Please enter a prompt.")
            return

        # Show loading spinner
        with st.spinner("Generating text..."):
            try:
                # Generate text
                generated_text = generate_prediction(model, prompt, max_length)

                # Display generated text
                st.subheader("Generated Text:")
                st.write(generated_text)

            except Exception as e:
                st.error(f"An error occurred during text generation: {e}")


if __name__ == "__main__":
    main()
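
The app is launched with `streamlit run app.py` from the repository root, so that the `src.model` import resolves. Note that `greedy_decode` only needs a callable mapping (batch, seq_len) token ids to (batch, seq_len, vocab) logits plus a tokenizer exposing `eos_token_id`, so it can be sanity-checked without downloading the checkpoint. A minimal sketch; the ToyLM and ToyTokenizer classes below are made up purely for illustration:

import torch
import torch.nn as nn

from app import greedy_decode  # importing app.py only defines functions, it does not launch the UI


class ToyLM(nn.Module):
    # Stand-in model: embeds tokens and projects straight back to the vocabulary.
    def __init__(self, vocab_size=32, emb_dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, ids):
        return self.head(self.emb(ids))  # (B, T, vocab_size)


class ToyTokenizer:
    eos_token_id = 0  # greedy_decode stops early when this id is produced


ids = torch.randint(1, 32, (1, 4))  # a fake 4-token prompt; the decode loop assumes batch size 1
out = greedy_decode(ToyLM(), ids, max_length=10, tokenizer=ToyTokenizer())
print(out.shape)  # at most (1, 10); shorter if EOS is hit first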
src/__init__.py
ADDED
Empty file
src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (173 Bytes)

src/__pycache__/model.cpython-310.pyc
ADDED
Binary file (5.78 kB)

src/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.73 kB)
src/model.py
ADDED
@@ -0,0 +1,201 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from src.utils import LlamaRotaryEmbedding, repeat_kv


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Root Mean Square Layer Normalization
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class Attention(nn.Module):
    """Multi-head attention module with support for GQA (Grouped Query Attention)."""

    def __init__(self, config):
        super(Attention, self).__init__()
        self.emb_dim = config.emb_dim
        self.n_q_heads = config.n_q_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = self.emb_dim // self.n_q_heads
        self.n_rep = self.n_q_heads // self.n_kv_heads

        # Projections for Q, K, V & O
        self.q_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
        self.k_proj = nn.Linear(
            self.emb_dim, self.head_dim * self.n_kv_heads, bias=False
        )
        self.v_proj = nn.Linear(
            self.emb_dim, self.head_dim * self.n_kv_heads, bias=False
        )
        self.o_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)

        # Initialize rotary embeddings
        self.rotary_embedding = LlamaRotaryEmbedding(
            dim=self.head_dim, max_seq_len=config.max_seq_len
        )

        # Dropout layers
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Causal mask
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len)).view(
                1, 1, config.max_seq_len, config.max_seq_len
            ),
        )

    def forward(self, x):
        B, T, C = x.size()  # batch_size, seq_len, emb_dim

        # Project Q, K, V
        q = self.q_proj(x)  # (B, T, emb_dim)
        k = self.k_proj(x)  # (B, T, n_kv_heads * head_dim)
        v = self.v_proj(x)  # (B, T, n_kv_heads * head_dim)

        # Reshape Q, K, V
        q = q.view(B, T, self.n_q_heads, self.head_dim)  # (B, T, n_q_heads, head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)  # (B, T, n_kv_heads, head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)  # (B, T, n_kv_heads, head_dim)

        # Reshape for attention computation
        q = q.transpose(1, 2)  # (B, n_q_heads, T, head_dim)
        k = k.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)
        v = v.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)

        # Apply rotary embeddings
        q, k = self.rotary_embedding(q, k)

        # Repeat K and V for GQA
        k = repeat_kv(k, self.n_rep)  # (B, n_q_heads, T, head_dim)
        v = repeat_kv(v, self.n_rep)  # (B, n_q_heads, T, head_dim)

        # Compute attention scores
        scale = 1.0 / math.sqrt(self.head_dim)
        att = (q @ k.transpose(-2, -1)) * scale  # (B, n_q_heads, T, T)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        # Apply attention to values
        y = att @ v  # (B, n_q_heads, T, head_dim)

        # Reshape and project output
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, emb_dim)
        y = self.o_proj(y)
        y = self.resid_dropout(y)

        return y


class FeedForward(nn.Module):
    """Feed-forward module with SiLU activation."""

    def __init__(self, config):
        super(FeedForward, self).__init__()
        # Gate and up-projections project from hidden_size to intermediate_size
        self.gate_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)

        # Down projection brings the dimension back to hidden_size
        self.down_proj = nn.Linear(config.intermediate_size, config.emb_dim, bias=False)

        # SiLU activation function
        self.act_fn = F.silu

        # Dropout layer
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Apply gate and up projections
        gate_output = self.act_fn(self.gate_proj(x))  # SiLU activation
        up_output = self.up_proj(x)

        # Element-wise multiplication of gate and up projections
        intermediate_output = gate_output * up_output

        # Project back to hidden size
        output = self.down_proj(intermediate_output)
        output = self.dropout(output)

        return output


class TransformerBlock(nn.Module):
    """Transformer block with attention and feed-forward modules."""

    def __init__(self, config):
        super(TransformerBlock, self).__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.input_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        self.attention_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)

    def forward(self, x):
        x = x + self.attention(self.input_layernorm(x))
        x = x + self.feed_forward(self.attention_layernorm(x))

        return x


class SmolLM(nn.Module):
    """Small language model with transformer blocks."""

    def __init__(self, config):
        super(SmolLM, self).__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.emb_dim)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_layers)]
        )

        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        self.apply(self._init_weights)
        self.layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)

        # weight sharing
        self.lm_head.weight = self.wte.weight

    def total_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x):
        x = self.wte(x)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.layernorm(x)
        logits = self.lm_head(x)
        return logits


# @dataclass
# class Config:
#     vocab_size: int = 49152
#     emb_dim: int = 576
#     intermediate_size: int = 1536
#     num_layers: int = 10
#     n_q_heads: int = 9
#     n_kv_heads: int = 3
#     max_seq_len: int = 8192
#     dropout: float = 0.1
#     rms_norm_eps: float = 1e-05
#     init_std: float = 0.041666666666666664
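
As a quick end-to-end check, the model can be built with the same configuration values the app uses and run on a dummy batch of token ids. A minimal sketch (CPU, random weights; the Config dataclass below simply mirrors MainConfig from app.py):

import torch
from dataclasses import dataclass

from src.model import SmolLM


@dataclass
class Config:  # same values as MainConfig in app.py
    vocab_size: int = 49152
    emb_dim: int = 576
    intermediate_size: int = 1536
    num_layers: int = 30
    n_q_heads: int = 9
    n_kv_heads: int = 3
    max_seq_len: int = 1024
    dropout: float = 0.1
    rms_norm_eps: float = 1e-05
    init_std: float = 0.041666666666666664


model = SmolLM(Config())
tokens = torch.randint(0, 49152, (2, 16))  # dummy batch: 2 sequences of 16 tokens
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 16, 49152])
print(f"{model.total_params() / 1e6:.1f}M trainable parameters")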
src/utils.py
ADDED
@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class LlamaRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim: int = 64,  # Dimension per attention head
        max_seq_len: int = 2048,  # Maximum sequence length
        base: int = 10000,  # Base for the angle calculations
        device: str = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Create cache for position frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Create position sequence
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_len: int):
        # Return early if cache is valid
        if seq_len <= self._seq_len_cached:
            return

        # Update cache size
        self._seq_len_cached = seq_len

        # Create position sequence
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        # Calculate position frequencies
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Calculate embeddings
        emb = torch.cat((freqs, freqs), dim=-1)
        self._cos_cached = emb.cos()  # [None, None, :, :]
        self._sin_cached = emb.sin()  # [None, None, :, :]

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch, num_heads, seq_len, head_dim = q.shape

        # Update cos/sin tables if needed
        self._update_cos_sin_tables(q, seq_len)

        # Get cos and sin for current sequence
        cos = (
            self._cos_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)
        )  # Shape: [1, 1, seq_len, dim]
        sin = (
            self._sin_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)
        )  # Shape: [1, 1, seq_len, dim]

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        # Apply rotary embeddings to q and k
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)

        return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden
    states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
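
A short shape check of the two helpers, using the head layout implied by the app's config (9 query heads, 3 KV heads, head_dim 576 / 9 = 64); the remaining tensor sizes below are arbitrary and chosen only for illustration:

import torch

from src.utils import LlamaRotaryEmbedding, repeat_kv

# repeat_kv: each of the 3 KV heads is repeated 3x to match the 9 query heads
kv = torch.randn(2, 3, 16, 64)  # (batch, n_kv_heads, seq_len, head_dim)
print(repeat_kv(kv, 3).shape)  # torch.Size([2, 9, 16, 64])

# Rotary embedding leaves shapes (and per-position norms) unchanged
rope = LlamaRotaryEmbedding(dim=64, max_seq_len=128)
q = torch.randn(2, 9, 16, 64)
k = torch.randn(2, 3, 16, 64)
q_rot, k_rot = rope(q, k)
print(q_rot.shape, k_rot.shape)  # same shapes as q and k
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))  # True: pure rotation per pair of dims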