inf-wse-v1-base-zh / modeling_sparse.py
SamuelYang's picture
Upload 8 files
e9c77ac verified
raw
history blame contribute delete
No virus
1.42 kB
import torch
import torch.nn as nn
from transformers import RoFormerModel, RoFormerPreTrainedModel
class RoFormerForSparseEmbedding(RoFormerPreTrainedModel):
    """RoFormer encoder with a scalar head that yields vocabulary-sized sparse embeddings.

    Each token's final hidden state is projected to a single weight by
    ``linear_layer``. ``forward`` then pools those weights per input id into
    one non-negative vector of length ``vocab_size`` for every sequence.
    """

    def __init__(self, config):
        super().__init__(config)
        self.encoder = RoFormerModel(config)
        # One scalar weight per token position.
        self.linear_layer = nn.Linear(config.hidden_size, 1)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        return_sparse: bool = False,
    ) -> torch.Tensor:
        """Compute a vocabulary-sized embedding for each input sequence.

        For every vocabulary id, the output holds the largest positive weight
        among the positions where that id occurs in the sequence, and 0 when
        the id does not occur or all of its weights are non-positive.

        Args:
            input_ids: Token ids, shape ``[B, L]``; used as a scatter index,
                so they must be an int64 tensor.
            attention_mask: Shape ``[B, L]``; non-zero marks a real token.
                Integer, bool and float masks are accepted.
            return_sparse: When true, return the result as a sparse tensor
                (``Tensor.to_sparse()``) instead of a dense one.

        Returns:
            Tensor of shape ``[B, vocab_size]`` with non-negative entries.
        """
        batch_size = input_ids.shape[0]
        last_hidden_states = self.encoder(input_ids, attention_mask)['last_hidden_state']  # [B,L,D]
        token_weights = self.linear_layer(last_hidden_states).squeeze(-1)  # [B,L]

        # Normalise the mask to int64 so that ``1 - mask`` works for bool masks
        # and the derived scatter index is a valid integer tensor for float masks.
        mask = attention_mask.to(torch.long)

        # Additive penalty: a large negative value pushes a position below zero
        # so the zero baseline of the reduction below discards it.
        token_mask = (1 - mask) * -1e4  # [B,L]
        # Suppress the first position and the last attended position.
        # NOTE(review): presumably the [CLS]/[SEP] special tokens - confirm
        # against the tokenizer. The last-position index assumes right padding.
        token_mask[:, 0] = -1e4
        # clamp: a row with no attended tokens would otherwise produce index -1
        # and make scatter raise; such a row is fully masked already.
        last_ind = (torch.sum(mask, -1, keepdim=True) - 1).clamp(min=0)  # [B,1]
        token_mask = torch.scatter(token_mask, -1, last_ind, -1e4)  # [B,L]
        token_weights = token_weights + token_mask  # [B,L]

        # Max-pool the weights per vocabulary id directly into a [B,V] tensor.
        # Starting from zeros with include_self=True is equivalent to applying
        # ReLU and taking the max over positions, without allocating a dense
        # [B,L,V] intermediate.
        emb = torch.zeros(batch_size, self.encoder.config.vocab_size, dtype=token_weights.dtype,
                          device=token_weights.device)  # [B,V]
        emb = emb.scatter_reduce(-1, input_ids, token_weights, reduce="amax", include_self=True)  # [B,V]

        if return_sparse:
            emb = emb.to_sparse()
        return emb