import torch
import torch.nn as nn
from transformers import RoFormerModel, RoFormerPreTrainedModel


class RoFormerForSparseEmbedding(RoFormerPreTrainedModel):
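    """RoFormer encoder with a scalar token-weighting head.

    Each token's hidden state is mapped to a single weight, the weights are
    scattered into a vocabulary-sized vector, and max-pooling over the
    sequence yields a sparse lexical embedding of shape [B, V].
    """
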
    def __init__(self, config):
        super().__init__(config)

        self.encoder = RoFormerModel(config)
        self.linear_layer = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_ids, attention_mask, return_sparse=False):
        B, L = input_ids.shape

        last_hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask)['last_hidden_state']  # [B,L,D]
        token_weights = self.linear_layer(last_hidden_states).squeeze(-1)  # [B,L]
        # Build an additive mask that suppresses padding, the leading [CLS]
        # token, and the trailing [SEP] token with a large negative value.
        token_mask = (1 - attention_mask) * -1e4  # [B,L] padding positions
        token_mask[:, 0] = -1e4  # first token ([CLS])
        last_ind = torch.sum(attention_mask, -1, keepdim=True) - 1  # [B,1] index of last real token ([SEP])
        token_mask = torch.scatter(token_mask, -1, last_ind, -1e4)  # [B,L]
        token_weights = token_weights + token_mask  # [B,L]

        # Scatter each token's weight into its vocabulary slot (note the dense
        # [B,L,V] intermediate), then ReLU and max-pool over the sequence:
        # masked positions carry large negative weights, so ReLU zeroes them.
        emb = torch.zeros(B, L, self.encoder.config.vocab_size, dtype=token_weights.dtype,
                          device=token_weights.device)  # [B,L,V]
        emb = torch.scatter(emb, dim=-1, index=input_ids.unsqueeze(-1), src=token_weights.unsqueeze(-1))  # [B,L,V]
        emb = torch.max(torch.relu(emb), dim=-2).values  # [B,V]

        if return_sparse:
            emb = emb.to_sparse()

        return emb
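

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original file).
    # The checkpoint name below is only an example; substitute any RoFormer
    # checkpoint whose tokenizer matches the model's vocabulary. Note that the
    # scalar head is newly initialized unless the checkpoint provides it.
    from transformers import AutoTokenizer

    checkpoint = "junnyu/roformer_chinese_base"  # example checkpoint; substitute your own
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = RoFormerForSparseEmbedding.from_pretrained(checkpoint)
    model.eval()

    batch = tokenizer(["a short example sentence"], return_tensors="pt", padding=True)
    with torch.no_grad():
        emb = model(batch["input_ids"], batch["attention_mask"], return_sparse=True)
    print(emb.shape)  # torch.Size([1, vocab_size]); emb is a sparse COO tensor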