import torch
import torch.nn as nn
import torch.nn.functional as F


class LlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to query and key tensors."""

    def __init__(
        self,
        dim: int = 64,
        max_seq_len: int = 2048,
        base: int = 10000,
        device: torch.device | str | None = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Inverse frequencies theta_i = base^(-2i / dim) for i in [0, dim / 2).
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device).float() / dim)
        )
        self.register_buffer("inv_freq", inv_freq)

        # cos/sin tables are built lazily and grown on demand in _update_cos_sin_tables.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x: torch.Tensor, seq_len: int):
        # Reuse the cached tables when they are long enough and on the right device.
        if (
            seq_len <= self._seq_len_cached
            and self._cos_cached is not None
            and self._cos_cached.device == x.device
        ):
            return

        self._seq_len_cached = seq_len

        # Outer product of positions and inverse frequencies: (seq_len, dim / 2).
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # Duplicate along the last dim so cos/sin cover the full head dimension.
        emb = torch.cat((freqs, freqs), dim=-1)
        self._cos_cached = emb.cos()
        self._sin_cached = emb.sin()

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch, num_heads, seq_len, head_dim = q.shape

        self._update_cos_sin_tables(q, seq_len)

        # Broadcast the tables over batch and heads:
        # (seq_len, head_dim) -> (1, 1, seq_len, head_dim).
        cos = self._cos_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)
        sin = self._sin_cached[:seq_len, :].unsqueeze(0).unsqueeze(0)

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        # Apply the rotation to queries and keys; k broadcasts over its own head count.
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)

        return q_embed, k_embed

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a new axis and expand it (a view, no copy), then flatten it into the head axis.
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
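

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal sketch
# assuming 8 query heads sharing 2 key/value heads (grouped-query attention),
# head_dim 64, and a causal scaled-dot-product attention call to tie the
# rotary embedding and repeat_kv together.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64

    rope = LlamaRotaryEmbedding(dim=head_dim)
    q = torch.randn(batch, num_heads, seq_len, head_dim)
    k = torch.randn(batch, num_kv_heads, seq_len, head_dim)
    v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

    # Rotate queries and keys by their absolute positions.
    q_rot, k_rot = rope(q, k)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape

    # Expand the 2 key/value heads to match the 8 query heads before attention.
    n_rep = num_heads // num_kv_heads
    k_rep = repeat_kv(k_rot, n_rep)
    v_rep = repeat_kv(v, n_rep)
    assert k_rep.shape == (batch, num_heads, seq_len, head_dim)

    out = F.scaled_dot_product_attention(q_rot, k_rep, v_rep, is_causal=True)
    print("attention output shape:", tuple(out.shape))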