# Author: Adityak204
# Commit: 70a0a5b ("Initial commit")
import torch
import torch.nn as nn
import torch.nn.functional as F
class LlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) for query/key tensors.

    Inverse frequencies ``1 / base ** (2i / dim)`` are stored in the
    ``inv_freq`` buffer. cos/sin tables over positions ``0 .. seq_len - 1`` are
    built lazily, and ``forward`` rotates the last dimension of ``q`` and ``k``
    using the "rotate half" pairing: element ``i`` is paired with element
    ``i + dim // 2``.
    """

    def __init__(
        self,
        dim: int = 64,  # Dimension per attention head
        max_seq_len: int = 2048,  # Stored only; the cos/sin cache grows on demand
        base: int = 10000,  # Base for the angle calculations
        device: str = None,  # Device on which ``inv_freq`` is created
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        # One inverse frequency per rotated pair of channels.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device).float() / dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        # Lazily-built float32 cos/sin tables (see _update_cos_sin_tables).
        # These are plain attributes, not buffers, so module.to(...) does not
        # move them; the cache check below handles device changes instead.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_len: int):
        """Ensure the cos/sin cache covers ``seq_len`` positions on ``x.device``.

        The tables are rebuilt when they do not exist yet, are too short, or
        live on a different device than ``x``. They are always computed in
        float32 so that a reduced-precision ``inv_freq`` buffer (e.g. after
        ``module.half()``) does not also degrade the angle computation.
        """
        cache_ok = (
            self._cos_cached is not None
            and self._cos_cached.device == x.device
            and seq_len <= self._seq_len_cached
        )
        if cache_ok:
            return
        # When rebuilding only because the device changed, keep the longer table.
        new_len = max(seq_len, self._seq_len_cached)
        t = torch.arange(new_len, device=x.device, dtype=torch.float32)
        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)  # [new_len, len(inv_freq)]
        # Both halves share the same angles, matching the rotate-half pairing.
        emb = torch.cat((freqs, freqs), dim=-1)
        self._cos_cached = emb.cos()
        self._sin_cached = emb.sin()
        self._seq_len_cached = new_len

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Map the last-dim halves ``[x1, x2]`` to ``[-x2, x1]``."""
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to ``q`` and ``k``.

        Args:
            q: Query tensor whose last two dims are ``[seq_len, head_dim]``,
                e.g. ``[batch, num_heads, seq_len, head_dim]``. Positions are
                ``0 .. seq_len - 1``.
            k: Key tensor that must broadcast against the
                ``[seq_len, head_dim]`` tables derived from ``q``.

        Returns:
            ``(q_embed, k_embed)``, each with the shape and dtype of its input.
        """
        seq_len = q.shape[-2]
        self._update_cos_sin_tables(q, seq_len)
        # [seq_len, dim] tables broadcast over any leading (batch, head) dims.
        cos = self._cos_cached[:seq_len]
        sin = self._sin_cached[:seq_len]
        # Cast to each input's dtype so float32 tables don't promote half/bf16.
        q_embed = (q * cos.to(q.dtype)) + (self._rotate_half(q) * sin.to(q.dtype))
        k_embed = (k * cos.to(k.dtype)) + (self._rotate_half(k) * sin.to(k.dtype))
        return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate every key/value head ``n_rep`` times along dimension 1.

    Gives the same result as
    ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``: an input
    of shape ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. When ``n_rep``
    is 1, the input tensor object itself is returned.

    Raises:
        ValueError: if ``hidden_states`` is not 4-D (from the shape unpack).
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep != 1:
        # Add a repeat axis right after the head axis, broadcast it to n_rep,
        # then merge it into the head axis so copies of a head sit adjacently.
        widened = hidden_states.unsqueeze(2).expand(
            bsz, kv_heads, n_rep, seq_len, head_dim
        )
        return widened.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
    return hidden_states