Adityak204's picture
Initial commit
70a0a5b
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from src.utils import LlamaRotaryEmbedding, repeat_kv
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Scales each vector along the last dimension by the reciprocal of its
    root-mean-square, then multiplies by a learned per-feature gain::

        y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight

    There is no mean subtraction and no bias term.

    Args:
        dim: size of the last input dimension (length of ``weight``).
        eps: constant added to the mean square before the reciprocal sqrt,
            guarding against division by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learned gain, initialized to 1 so the layer starts as pure scaling.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Accumulate the mean square in at least float32. In float16 the square
        # of any |x| > ~256 overflows to inf, rsqrt(inf) == 0, and the whole
        # vector would be silently zeroed. float64 inputs keep float64.
        input_dtype = x.dtype
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x_c = x.to(compute_dtype)
        rms = torch.rsqrt(x_c.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast the normalized activations back before applying the gain, so
        # the output dtype follows the same promotion as before (x * weight).
        return (x_c * rms).to(input_dtype) * self.weight
class Attention(nn.Module):
    """Causal multi-head self-attention with Grouped Query Attention (GQA).

    Queries are split into ``config.n_q_heads`` heads. Keys and values use
    ``config.n_kv_heads`` heads, which ``repeat_kv`` expands by a factor of
    ``n_q_heads // n_kv_heads`` before scoring. Rotary embeddings are applied
    to Q and K. A lower-triangular mask blocks attention to later positions.

    Config fields read: ``emb_dim``, ``n_q_heads``, ``n_kv_heads``,
    ``max_seq_len``, ``dropout``.

    Raises:
        ValueError: if ``emb_dim`` is not divisible by ``n_q_heads``, or
            ``n_q_heads`` is not divisible by ``n_kv_heads``.
    """

    def __init__(self, config):
        super(Attention, self).__init__()
        # Reject head layouts that cannot be reshaped consistently. Otherwise
        # the module builds fine and only fails later in forward() with an
        # opaque view / matmul shape error.
        if config.emb_dim % config.n_q_heads != 0:
            raise ValueError(
                f"emb_dim ({config.emb_dim}) must be divisible by "
                f"n_q_heads ({config.n_q_heads})"
            )
        if config.n_q_heads % config.n_kv_heads != 0:
            raise ValueError(
                f"n_q_heads ({config.n_q_heads}) must be divisible by "
                f"n_kv_heads ({config.n_kv_heads})"
            )
        self.emb_dim = config.emb_dim
        self.n_q_heads = config.n_q_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = self.emb_dim // self.n_q_heads
        # Number of query heads sharing each K/V head.
        self.n_rep = self.n_q_heads // self.n_kv_heads
        # Projections for Q, K, V & O (K and V are narrower under GQA).
        self.q_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
        self.k_proj = nn.Linear(
            self.emb_dim, self.head_dim * self.n_kv_heads, bias=False
        )
        self.v_proj = nn.Linear(
            self.emb_dim, self.head_dim * self.n_kv_heads, bias=False
        )
        self.o_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
        # Rotary position embeddings, applied per head.
        self.rotary_embedding = LlamaRotaryEmbedding(
            dim=self.head_dim, max_seq_len=config.max_seq_len
        )
        # Dropout on the attention probabilities and on the projected output.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Causal mask: 1 where query position i may attend key position j <= i.
        # Registered as a persistent buffer, so it is part of state_dict.
        # NOTE(review): this is max_seq_len**2 floats per layer (256 MiB in
        # float32 at max_seq_len=8192). It is kept as-is so existing checkpoints
        # that contain "mask" still load.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len)).view(
                1, 1, config.max_seq_len, config.max_seq_len
            ),
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (batch, seq_len, emb_dim).

        Returns:
            Tensor of shape (batch, seq_len, emb_dim).

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_seq_len`` the causal
                mask was built for.
        """
        B, T, C = x.size()  # batch_size, seq_len, emb_dim
        max_len = self.mask.size(-1)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {max_len}"
            )
        # Project Q, K, V
        q = self.q_proj(x)  # (B, T, emb_dim)
        k = self.k_proj(x)  # (B, T, n_kv_heads * head_dim)
        v = self.v_proj(x)  # (B, T, n_kv_heads * head_dim)
        # Split heads and move the head axis ahead of the sequence axis.
        q = q.view(B, T, self.n_q_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # q: (B, n_q_heads, T, head_dim); k, v: (B, n_kv_heads, T, head_dim)
        q, k = self.rotary_embedding(q, k)
        # Expand K/V so each query head has a matching K/V head.
        k = repeat_kv(k, self.n_rep)  # (B, n_q_heads, T, head_dim)
        v = repeat_kv(v, self.n_rep)  # (B, n_q_heads, T, head_dim)
        # Scaled dot-product scores with the causal mask applied.
        scale = 1.0 / math.sqrt(self.head_dim)
        att = (q @ k.transpose(-2, -1)) * scale  # (B, n_q_heads, T, T)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, n_q_heads, T, head_dim)
        # Merge heads back to (B, T, emb_dim) and project.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.o_proj(y)
        y = self.resid_dropout(y)
        return y
class FeedForward(nn.Module):
    """Gated (SwiGLU-style) feed-forward network.

    Computes ``dropout(down_proj(silu(gate_proj(x)) * up_proj(x)))``.
    ``gate_proj`` and ``up_proj`` map ``config.emb_dim`` to
    ``config.intermediate_size``, and ``down_proj`` maps back. No projection
    has a bias.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.emb_dim
        d_hidden = config.intermediate_size
        # Gate branch (passed through the activation) and linear "up" branch.
        self.gate_proj = nn.Linear(d_model, d_hidden, bias=False)
        self.up_proj = nn.Linear(d_model, d_hidden, bias=False)
        # Projection back to the model width.
        self.down_proj = nn.Linear(d_hidden, d_model, bias=False)
        self.act_fn = F.silu
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # The activated gate modulates the up-projection element-wise.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
class TransformerBlock(nn.Module):
    """Pre-norm decoder block made of two residual sub-layers.

    ``h = x + attention(input_layernorm(x))``, then
    ``out = h + feed_forward(attention_layernorm(h))``. Each sub-layer receives
    an RMS-normalized copy of its input, and its output is added to the
    un-normalized residual stream.
    """

    def __init__(self, config):
        super().__init__()
        width = config.emb_dim
        eps = config.rms_norm_eps
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        # Norm applied to the input of self-attention.
        self.input_layernorm = RMSNorm(width, eps)
        # Norm applied to the input of the feed-forward (after the attention residual).
        self.attention_layernorm = RMSNorm(width, eps)

    def forward(self, x):
        h = x + self.attention(self.input_layernorm(x))
        return h + self.feed_forward(self.attention_layernorm(h))
class SmolLM(nn.Module):
    """Decoder-only language model.

    Pipeline: token embedding (``wte``) -> ``config.num_layers`` transformer
    blocks -> final RMSNorm (``layernorm``) -> vocabulary projection
    (``lm_head``). ``lm_head`` shares its weight with ``wte``.

    Config fields read here: ``vocab_size``, ``emb_dim``, ``num_layers``,
    ``rms_norm_eps``, ``init_std``. ``TransformerBlock`` reads additional fields.
    """

    def __init__(self, config):
        super(SmolLM, self).__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.emb_dim)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_layers)]
        )
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        self.apply(self._init_weights)
        # The final norm is registered after init. _init_weights does not touch
        # RMSNorm anyway; its gain is created as ones.
        self.layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        # Weight sharing: lm_head reuses the embedding matrix. The lm_head
        # weight initialized above is discarded.
        self.lm_head.weight = self.wte.weight

    def total_params(self):
        """Return the number of trainable parameter elements.

        ``nn.Module.parameters()`` yields each Parameter once, so the tied
        ``wte``/``lm_head`` weight is counted a single time.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _init_weights(self, module):
        """Initialize one submodule in place (called via ``self.apply``).

        Linear and Embedding weights are drawn from N(0, config.init_std).
        Linear biases are zeroed. LayerNorm gain is set to 1 and bias to 0,
        for whichever of those exists. Other module types are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # A LayerNorm built with elementwise_affine=False (or bias=False)
            # has weight/bias set to None; skip whichever is absent.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map token ids ``x`` to logits over the vocabulary.

        ``x`` is presumably an integer tensor of shape (batch, seq_len); confirm
        against callers. The returned logits have a trailing dimension of size
        ``config.vocab_size``.
        """
        x = self.wte(x)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.layernorm(x)
        logits = self.lm_head(x)
        return logits
# @dataclass
# class Config:
# vocab_size: int = 49152
# emb_dim: int = 576
# intermediate_size: int = 1536
# num_layers: int = 10
# n_q_heads: int = 9
# n_kv_heads: int = 3
# max_seq_len: int = 8192
# dropout: float = 0.1
# rms_norm_eps: float = 1e-05
# init_std: float = 0.041666666666666664