import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.utils import LlamaRotaryEmbedding, repeat_kv


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean centering, no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the RMS of the last dimension, then apply the learned gain.
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class Attention(nn.Module):
    """Multi-head attention with support for GQA (grouped-query attention)."""

    def __init__(self, config):
        super().__init__()
        self.emb_dim = config.emb_dim
        self.n_q_heads = config.n_q_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = self.emb_dim // self.n_q_heads
        # Each key/value head is shared by n_rep query heads.
        self.n_rep = self.n_q_heads // self.n_kv_heads

        # Queries keep the full width; keys and values are narrower under GQA.
        self.q_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
        self.k_proj = nn.Linear(self.emb_dim, self.head_dim * self.n_kv_heads, bias=False)
        self.v_proj = nn.Linear(self.emb_dim, self.head_dim * self.n_kv_heads, bias=False)
        self.o_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)

        self.rotary_embedding = LlamaRotaryEmbedding(
            dim=self.head_dim, max_seq_len=config.max_seq_len
        )

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Causal mask, broadcastable over (batch, heads, T, T).
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len)).view(
                1, 1, config.max_seq_len, config.max_seq_len
            ),
        )

    def forward(self, x):
        B, T, C = x.size()

        # Project the input into queries, keys, and values.
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Split into heads: (B, T, n_heads, head_dim).
        q = q.view(B, T, self.n_q_heads, self.head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)

        # Move the head dimension forward: (B, n_heads, T, head_dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Apply rotary position embeddings to queries and keys.
        q, k = self.rotary_embedding(q, k)

        # GQA: repeat each key/value head n_rep times to match the query head count.
        k = repeat_kv(k, self.n_rep)
        v = repeat_kv(v, self.n_rep)

        # Scaled dot-product attention with a causal mask.
        scale = 1.0 / math.sqrt(self.head_dim)
        att = (q @ k.transpose(-2, -1)) * scale
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v

        # Re-assemble the heads and project back to the embedding dimension.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.o_proj(y)
        y = self.resid_dropout(y)

        return y
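
    # Note: the mask/softmax/dropout sequence in forward() is functionally equivalent
    # to PyTorch's fused kernel, F.scaled_dot_product_attention (torch >= 2.0). A
    # rough drop-in alternative for that block would be:
    #
    #     y = F.scaled_dot_product_attention(
    #         q, k, v,
    #         dropout_p=self.attn_dropout.p if self.training else 0.0,
    #         is_causal=True,
    #     )
    #
    # The explicit path is kept here so the causal masking and attention weights stay visible.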


class FeedForward(nn.Module):
    """Gated feed-forward module with SiLU activation (SwiGLU): down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.emb_dim, bias=False)
        self.act_fn = F.silu
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Gated activation: SiLU on the gate branch, elementwise product with the up branch.
        gate_output = self.act_fn(self.gate_proj(x))
        up_output = self.up_proj(x)
        intermediate_output = gate_output * up_output

        # Project back down to the embedding dimension.
        output = self.down_proj(intermediate_output)
        output = self.dropout(output)

        return output


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward sublayers with residual connections."""

    def __init__(self, config):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.input_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        # Norm applied before the feed-forward sublayer (i.e. after attention).
        self.attention_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)

    def forward(self, x):
        # Pre-norm residual connections around attention and feed-forward.
        x = x + self.attention(self.input_layernorm(x))
        x = x + self.feed_forward(self.attention_layernorm(x))
        return x


class SmolLM(nn.Module):
    """Small language model built from transformer blocks."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.emb_dim)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_layers)]
        )
        self.layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)

        self.apply(self._init_weights)

        # Tie the output head to the token embedding matrix.
        self.lm_head.weight = self.wte.weight

    def total_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x):
        # Token embeddings -> transformer blocks -> final RMSNorm -> vocabulary logits.
        x = self.wte(x)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.layernorm(x)
        logits = self.lm_head(x)
        return logits
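

# Minimal usage sketch: builds a throwaway config carrying exactly the fields the
# modules above read, then runs one forward pass. It assumes src.utils provides
# LlamaRotaryEmbedding and repeat_kv with the interfaces used in Attention; the
# project's real config class may carry additional fields.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        vocab_size=1000,
        emb_dim=64,
        n_q_heads=4,
        n_kv_heads=2,       # n_q_heads must be divisible by n_kv_heads (GQA)
        intermediate_size=256,
        num_layers=2,
        max_seq_len=128,
        dropout=0.0,
        rms_norm_eps=1e-6,
        init_std=0.02,
    )

    model = SmolLM(config)
    tokens = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq_len)
    logits = model(tokens)
    print(f"params: {model.total_params():,}, logits shape: {tuple(logits.shape)}")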