Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.nn.functional as F | |
from torch.nn import Linear, Identity, Module | |
def default(v, d):
    """Return ``v`` when it is not ``None``; otherwise return the fallback ``d``.

    Only ``None`` triggers the fallback: falsy values such as ``0`` or ``[]``
    are returned unchanged.
    """
    if v is None:
        return d
    return v
def exists(v):
    """Return True if ``v`` is set, meaning it is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or an empty list still count as set.
    """
    if v is None:
        return False
    return True
def heinsen_associative_scan_log(log_coeffs, log_values):
    """Evaluate the linear recurrence h_t = a_t * h_{t-1} + b_t for every t at once.

    Uses Heinsen's log-space parallel scan. With A_t = sum_{j<=t} log a_j:

        log h_t = A_t + logcumsumexp_{k<=t}(log b_k - A_k)

    which corresponds to starting the recurrence from h_{-1} = 0. The scan runs
    along dim 1.

    Args:
        log_coeffs: log a_t (the multiplicative coefficients).
        log_values: log b_t (the additive terms), assumed to match the shape
            of ``log_coeffs``.

    Returns:
        h itself (not its log): the exponential of the scanned log values.
    """
    cum_log_coeffs = torch.cumsum(log_coeffs, dim=1)
    # Divide each b_k by the running product of coefficients (in log space),
    # accumulate with a stable log-sum-exp, then multiply the product back in.
    rescaled = log_values - cum_log_coeffs
    log_h = cum_log_coeffs + torch.logcumsumexp(rescaled, dim=1)
    return torch.exp(log_h)
def log_g(x):
    """Return log(g(x)) element-wise for the minGRU activation g.

    g(x) = x + 0.5 for x >= 0 and sigmoid(x) for x < 0. The negative branch
    uses the identity log(sigmoid(x)) = -softplus(-x) for numerical stability.
    """
    # relu keeps the log argument positive even where this branch is discarded,
    # so torch.where never sees NaNs from log of a negative number.
    non_negative_branch = torch.log(F.relu(x) + 0.5)
    negative_branch = -F.softplus(-x)
    return torch.where(x >= 0, non_negative_branch, negative_branch)
class MinGRU(Module):
    """Minimal GRU (minGRU) layer evaluated with a parallel scan in log space.

    Per feature, with g(v) = v + 0.5 for v >= 0 and sigmoid(v) otherwise:

        z_t  = sigmoid(to_gate(x_t))
        h~_t = g(to_hidden(x_t))
        h_t  = (1 - z_t) * h_{t-1} + z_t * h~_t

    Args:
        dim: Size of the last dimension of the input.
        expansion_factor: Width multiplier for the recurrent state. When it is
            not 1, an output projection maps the state back to ``dim``.
    """

    def __init__(self, dim, expansion_factor=1.):
        super().__init__()
        dim_inner = int(dim * expansion_factor)
        # Separate projections for the candidate hidden state and the update gate.
        self.to_hidden = Linear(dim, dim_inner, bias=False)
        self.to_gate = Linear(dim, dim_inner, bias=False)
        # Output projection (Identity when the state width equals ``dim``).
        self.to_out = Linear(dim_inner, dim, bias=False) if expansion_factor != 1. else Identity()

    def forward(self, x, prev_hidden=None, return_next_prev_hidden=False):
        """Run the recurrence over the sequence axis (dim 1) in parallel.

        Args:
            x: Input, presumably (batch, seq, dim); the scan and the
                ``x.shape[1]`` slice treat dim 1 as the sequence axis.
            prev_hidden: Optional state to continue from, e.g. the
                ``next_prev_hidden`` returned by an earlier call
                (shape (batch, 1, dim_inner)). It must be strictly positive,
                as states produced by this layer are.
            return_next_prev_hidden: If True, also return the state at the last
                position so the next chunk can resume from it.

        Returns:
            The projected output for every position of ``x``. When
            ``return_next_prev_hidden`` is True, returns a tuple
            ``(out, next_prev_hidden)``, where ``next_prev_hidden`` is the
            un-projected state at the last position.
        """
        seq_len = x.shape[1]
        hidden = self.to_hidden(x)
        gate = self.to_gate(x)

        # Work in log space for numerical stability.
        log_coeffs = -F.softplus(gate)  # log(1 - sigmoid(gate))
        log_z = -F.softplus(-gate)  # log(sigmoid(gate))
        log_tilde_h = log_g(hidden)  # log(g(hidden))
        log_values = log_z + log_tilde_h  # log(z * h~)

        if exists(prev_hidden):
            # prev_hidden is an actual state h (not a pre-activation), so it is
            # moved to log space with a plain log. Applying log_g here would
            # compute log(h + 0.5), and chunked processing would then disagree
            # with processing the whole sequence at once.
            log_values = torch.cat((prev_hidden.log(), log_values), dim=1)
            # Leading coefficient for the prepended state. It cancels out of the
            # scan, so the first output equals prev_hidden exactly.
            log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))

        out = heinsen_associative_scan_log(log_coeffs, log_values)
        out = out[:, -seq_len:]  # drop the prepended initial state, if any

        # Last state, taken before the output projection so it can be fed back.
        next_prev_hidden = out[:, -1:]

        out = self.to_out(out)

        if not return_next_prev_hidden:
            return out
        return out, next_prev_hidden
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> hidden), GELU, Linear(hidden -> dim).

    Args:
        dim: Input and output feature size.
        mult: Hidden-width multiplier; the hidden size is ``int(dim * mult)``.
    """

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden_size = int(dim * mult)
        self.dim_inner = hidden_size
        layers = [
            nn.Linear(dim, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, dim),
        ]
        # Kept as a Sequential named ``net`` so parameter names
        # (net.0.*, net.2.*) stay stable for saved checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)
class CausalDepthWiseConv1d(nn.Module):
    """Causal depthwise-separable 1-D convolution along the sequence axis.

    The first convolution is per-channel (``groups=dim``) with width
    ``kernel_size``. It is followed by a 1x1 convolution that mixes channels.
    The sequence is left-padded with ``kernel_size - 1`` zeros, so output
    position t depends only on input positions <= t. Input and output layout
    is (batch, seq, dim).

    Args:
        dim: Number of channels.
        kernel_size: Temporal width of the depthwise filter.
    """

    def __init__(self, dim, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        depthwise = nn.Conv1d(dim, dim, kernel_size=kernel_size, groups=dim)
        pointwise = nn.Conv1d(dim, dim, kernel_size=1)
        self.net = nn.Sequential(depthwise, pointwise)

    def forward(self, x):
        channels_first = x.transpose(1, 2)  # (b, n, d) -> (b, d, n) for Conv1d
        left_pad = self.kernel_size - 1
        padded = F.pad(channels_first, (left_pad, 0), value=0.)
        mixed = self.net(padded)
        return mixed.transpose(1, 2)  # (b, d, n) -> (b, n, d)
class RMSNorm(nn.Module):
    """Normalization over the last dimension with a learnable per-feature gain.

    Each vector is scaled to unit L2 norm, multiplied by sqrt(dim), then
    multiplied by ``1 + gamma``. ``gamma`` starts at zero, so the initial gain
    is exactly 1.

    Args:
        dim: Size of the normalized (last) dimension.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        unit = F.normalize(x, dim=-1)
        gain = 1 + self.gamma
        return unit * self.scale * gain
class MinGRU_Layers(nn.Module):
    """Self-contained token-to-logits block built around a MinGRU.

    Pipeline: embedding -> RMSNorm -> MinGRU (residual around the normalized
    activations) -> FeedForward (residual) -> RMSNorm -> vocabulary logits,
    plus an optional cross-entropy loss.

    Args:
        dim: Model width.
        num_tokens: Vocabulary size (embedding rows and logit columns).
    """

    def __init__(self, dim, num_tokens):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        # NOTE(review): this module is registered (so it appears in
        # parameters() and state_dict) but forward() never calls it. Unused
        # parameters make DistributedDataParallel fail unless
        # find_unused_parameters=True. The "casual" spelling is part of the
        # checkpoint key names, so it is left as-is.
        self.casual_depth = CausalDepthWiseConv1d(dim=dim, kernel_size=3)
        self.rms_norm = RMSNorm(dim)
        self.gru = MinGRU(dim)
        self.ff = FeedForward(dim)
        self.norm = RMSNorm(dim)
        self.to_logits = nn.Linear(dim, num_tokens, bias=False)

    def forward(self, inputs, labels=None, is_first_layer=True, prev_hiddens=None):
        """Embed ``inputs``, run the block, and optionally score against ``labels``.

        Args:
            inputs: Token ids when ``is_first_layer`` is True. Otherwise a
                tensor whose last dimension is reduced with argmax to token ids
                (e.g. a previous layer's logits).
            labels: Optional targets for ``F.cross_entropy``; presumably
                (batch, seq) token ids, since logits are transposed to
                (batch, vocab, seq) before the loss. TODO confirm.
            is_first_layer: Selects how ``inputs`` are turned into token ids.
            prev_hiddens: Optional iterable of cached MinGRU states. When given,
                only the last position is processed and its first element is
                used as the MinGRU state.

        Returns:
            ``(loss, logits, next_prev_hiddens)``. ``loss`` is None when
            ``labels`` is None, and ``next_prev_hiddens`` is a one-element list
            holding the MinGRU's final state.
        """
        # NOTE(review): argmax is not differentiable, so no gradient reaches the
        # layer that produced ``inputs`` through this path.
        token_ids = inputs if is_first_layer else inputs.argmax(dim=-1)
        x = self.emb(token_ids)

        if exists(prev_hiddens):
            # With a cached state, only the newest position is processed.
            x = x[:, -1:]

        cached_states = iter(default(prev_hiddens, []))
        x = self.rms_norm(x)
        gru_out, new_hidden = self.gru(x, next(cached_states, None), return_next_prev_hidden=True)
        x = x + gru_out
        x = x + self.ff(x)
        logits = self.to_logits(self.norm(x))

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.transpose(1, 2), labels)
        return loss, logits, [new_hidden]
class MinGRU_LM(nn.Module):
    """Stack of ``MinGRU_Layers``, each producing its own vocabulary logits.

    Layer 0 consumes token ids. Every later layer consumes the previous layer's
    logits, which ``MinGRU_Layers`` re-embeds via argmax. When labels are
    given, every layer contributes a cross-entropy loss, and the losses are
    averaged.

    Args:
        dim: Model width shared by all layers.
        num_tokens: Vocabulary size.
        num_layers: Number of stacked layers; must be at least 1.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, dim, num_tokens, num_layers):
        super().__init__()
        if num_layers < 1:
            # forward() divides by the layer count and returns the last layer's
            # logits, so an empty stack could never produce a result.
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.layers = nn.ModuleList([MinGRU_Layers(dim, num_tokens) for _ in range(num_layers)])

    def forward(self, inputs, labels=None):
        """Run every layer and return ``(mean loss, final-layer logits)``.

        Args:
            inputs: Token ids for the first layer.
            labels: Optional targets passed to every layer's loss. Defaults to
                None so the model can be called for inference as
                ``model(inputs)``.

        Returns:
            ``(loss, logits)``. ``loss`` is the sum of per-layer losses divided
            by the number of layers (``0.0`` when ``labels`` is None), and
            ``logits`` come from the last layer.
        """
        total_loss = 0
        current_input = inputs
        logits = None
        for i, layer in enumerate(self.layers):
            # No recurrent state is carried between calls, so every layer
            # starts from an empty hidden state.
            loss, logits, _ = layer(
                inputs=current_input,
                labels=labels,
                is_first_layer=(i == 0),
                prev_hiddens=None,
            )
            if loss is not None:
                total_loss += loss
            current_input = logits  # the next layer re-embeds these via argmax
        return total_loss / len(self.layers), logits