import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from packaging import version
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    BaseModelOutputWithPooling,
    ModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
from transformers.tokenization_utils import BatchEncoding

from .configuration_nllbllm2vec import NLLBLLM2VecConfig
from .modeling_llama_encoder import LlamaEncoderModel

DEFAULT_TOKENIZE_KWARGS = {
    "padding": True,
    "truncation": True,
    "max_length": 512,
    "return_tensors": "pt",
}
DEFAULT_DATALOADER_KWARGS = {
    "shuffle": False,
    "batch_size": 32,
    "pin_memory": True,
}


def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
    def collate_fn(batch: list[str]) -> BatchEncoding:
        return tokenizer(batch, **tokenize_kwargs)

    return collate_fn


def defaulter(kwd_dict: Optional[Dict], default_dict: Dict) -> Dict:
    return default_dict if kwd_dict is None else {**default_dict, **kwd_dict}
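
# Note (illustrative, not executed): `defaulter` merges user-supplied kwargs over the
# defaults, so user values win key-by-key, e.g.
#   defaulter({"batch_size": 8}, DEFAULT_DATALOADER_KWARGS)
#   -> {"shuffle": False, "batch_size": 8, "pin_memory": True}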

@dataclass
class SequenceClassifierOutputWithPastAndPooler(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    pooler_output: torch.FloatTensor = None


class NLLBLLM2Vec(PreTrainedModel):
    """
    NLLBLLM2Vec model combining the NLLB and Llama encoders.

    Args:
        config (Optional[NLLBLLM2VecConfig]): Configuration object.
        nllb_encoder (Optional[M2M100Encoder]): Pre-initialized NLLB encoder.
        llm2vec (Optional[LlamaEncoderModel]): Pre-initialized Llama encoder.
        *inputs: Additional positional arguments.
        **kwargs: Additional keyword arguments.
    """

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config: Optional[NLLBLLM2VecConfig] = None,
        nllb_encoder: Optional[M2M100Encoder] = None,
        llm2vec: Optional[LlamaEncoderModel] = None,
        *inputs,
        **kwargs,
    ):
        # Ensure that either a config or both encoders are provided
        if config is None and (nllb_encoder is None or llm2vec is None):
            raise ValueError(
                "Either `config` must be provided, or both `nllb_encoder` and `llm2vec` must be specified."
            )
        if config is not None:
            super().__init__(config, *inputs, **kwargs)
            # from_pretrained overwrites this after config instantiation, so we make sure it's correctly set
            config.nllb_config._attn_implementation = config._attn_implementation
            config.llm2vec_config._attn_implementation = config._attn_implementation
            self.nllb_encoder = nllb_encoder or M2M100Encoder(config.nllb_config)
            self.llm2vec = llm2vec or LlamaEncoderModel(config.llm2vec_config)
            self.config = config
        else:
            # Both encoders are provided
            self.nllb_encoder = cast(M2M100Encoder, nllb_encoder)
            self.llm2vec = cast(LlamaEncoderModel, llm2vec)
            self.config = NLLBLLM2VecConfig(
                nllb_config=self.nllb_encoder.config,  # type: ignore
                llm2vec_config=self.llm2vec.config,  # type: ignore
            )
            super().__init__(self.config, *inputs, **kwargs)

        self.up_proj = nn.Linear(
            self.nllb_encoder.config.d_model,
            self.llm2vec.config.hidden_size,
            bias=False,
        )

        # TODO: update this once the commit is included in a release
        min_version = "4.46.0"
        if self.config.nllb_config._attn_implementation == "flash_attention_2":
            if version.parse(transformers.__version__) < version.parse(min_version):
                warnings.warn(
                    f"The installed transformers version ({transformers.__version__}) never disables "
                    "NLLB-encoder dropout when running with FlashAttention2. See "
                    "https://github.com/huggingface/transformers/pull/33844 for more info. "
                    f"Consider upgrading to at least {min_version} or installing from source.",
                    UserWarning,
                )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        indices: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask.
            indices (Optional[Tuple[torch.Tensor, torch.Tensor]]): Precomputed input indices and offsets.

        Returns:
            BaseModelOutputWithPooling: Model outputs with last hidden state and pooled output.
        """
        # Compute input indices and offsets if not provided
        if indices is None:
            seq_indices, seq_offsets = self._get_input_offsets(attention_mask)
        else:
            seq_indices, seq_offsets = indices

        nllb_outputs = self.nllb_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        nllb_last_hidden_state = nllb_outputs.last_hidden_state
        nllb_last_hidden_state = self.up_proj(nllb_last_hidden_state)

        outputs = self.llm2vec(
            inputs_embeds=nllb_last_hidden_state,
            attention_mask=attention_mask,
        )
        pooler_output = self._mean_embedding(
            hidden_states=outputs.last_hidden_state,
            input_indices=seq_indices,
            offsets=seq_offsets,
        )
        return BaseModelOutputWithPooling(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=pooler_output,
        )
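
    # Rough data flow of `forward` (descriptive comment; B = batch size, T = sequence length,
    # d_nllb = config.nllb_config.d_model, d_llm = config.llm2vec_config.hidden_size):
    #   input_ids (B, T) -> NLLB encoder              -> (B, T, d_nllb)
    #                    -> up_proj                   -> (B, T, d_llm)
    #                    -> LLM2Vec (Llama encoder)   -> (B, T, d_llm)
    #                    -> mean pool non-pad tokens  -> pooler_output (B, d_llm)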

    @property
    def tokenizer(self):
        """
        Get the tokenizer associated with the model.

        Returns:
            PreTrainedTokenizer: The tokenizer instance.
        """
        if not hasattr(self, "_tokenizer"):
            from transformers import AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(
                "facebook/nllb-200-distilled-600M", padding_side="right"
            )
        return self._tokenizer

    def encode(
        self,
        inputs: List[str],
        src_lang: str = "eng_Latn",
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
        tokenize_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn_closure: Optional[Callable] = None,
    ) -> torch.Tensor:
        """
        Encode input texts into embeddings.

        Args:
            inputs (List[str]): List of input texts.
            src_lang (str): Source language code for the tokenizer (default: `"eng_Latn"`).
            dataloader_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the dataloader
                excl. `collate_fn`. Defaults to:
                >> dataloader_kwargs = {
                >>     "shuffle": False,
                >>     "batch_size": 32,
                >>     "pin_memory": True,
                >> }
            tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
                Defaults to:
                >> tokenize_kwargs = {
                >>     "padding": True,
                >>     "truncation": True,
                >>     "max_length": 512,
                >>     "return_tensors": "pt",
                >> }
            collate_fn_closure (Optional[Callable]): Closure that should return a `collate_fn`.
                Defaults to:
                >> def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
                >>     def collate_fn(batch: list[str]) -> BatchEncoding:
                >>         return tokenizer(batch, **tokenize_kwargs)
                >>     return collate_fn

        Returns:
            torch.Tensor: Mean-pooled sequence embeddings of the inputs.
        """
        # Merge user kwargs with defaults, giving priority to user kwargs
        tokenize_kwargs = defaulter(tokenize_kwargs, DEFAULT_TOKENIZE_KWARGS)
        dataloader_kwargs = defaulter(dataloader_kwargs, DEFAULT_DATALOADER_KWARGS)

        tokenizer = self.tokenizer
        tokenizer.src_lang = src_lang
        device = next(self.parameters()).device

        if collate_fn_closure is None:
            collate_fn = default_collate_fn_closure(tokenizer, tokenize_kwargs)
        else:
            collate_fn = collate_fn_closure(tokenizer, tokenize_kwargs)
        assert (
            "collate_fn" not in dataloader_kwargs
        ), "`collate_fn` should be created via `collate_fn_closure`"

        self.eval()
        if len(inputs) > dataloader_kwargs.get("batch_size", 1):
            dataloader = DataLoader(inputs, collate_fn=collate_fn, **dataloader_kwargs)  # type: ignore
            all_embeddings = []
            # Iterate through the dataloader with a progress bar and autocast
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                for batch in tqdm(dataloader, desc="Encoding"):
                    # Move batch to device
                    batch = {k: v.to(device) for k, v in batch.items()}
                    # Forward pass through the model, keeping only the pooled sequence embeddings
                    with torch.inference_mode():
                        pooled_embeddings = cast(
                            BaseModelOutputWithPooling, self(**batch)
                        ).pooler_output
                    all_embeddings.append(pooled_embeddings)
            # Concatenate all pooled embeddings along the batch dimension
            all_embeddings = torch.cat(all_embeddings, dim=0)
        else:
            batch = {k: v.to(device) for k, v in collate_fn(inputs).items()}
            with torch.inference_mode():
                all_embeddings = cast(
                    BaseModelOutputWithPooling, self(**batch)
                ).pooler_output
        return all_embeddings

    @staticmethod
    def _get_input_offsets(
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute indices and offsets for mean pooling using EmbeddingBag.

        Args:
            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - input_indices: Indices of non-padded tokens in the flattened input.
                - offsets: Offsets indicating the start index of each sequence in the flattened input.
        """
        # Find the indices of non-padded tokens in the flattened hidden states
        input_indices = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()
        # Compute the offsets: for each sequence, where it starts in the flattened input
        non_padded_lengths = attention_mask.sum(dim=1)  # Count non-padded tokens per sequence
        offsets = non_padded_lengths.cumsum(dim=0).roll(shifts=1)
        offsets[0] = 0
        return input_indices, offsets
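
    # Worked toy example (illustrative, not executed): for an attention mask
    #   [[1, 1, 1, 0],
    #    [1, 1, 0, 0]]
    # `_get_input_offsets` returns input_indices = [0, 1, 2, 4, 5] and offsets = [0, 3],
    # so `F.embedding_bag(..., mode="mean")` in `_mean_embedding` below averages flattened
    # rows 0-2 for the first sequence and rows 4-5 for the second, skipping padded positions.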
    @staticmethod
    def _mean_embedding(
        hidden_states: torch.Tensor,
        input_indices: torch.Tensor,
        offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the mean of non-padded embeddings using `embedding_bag`, properly handling padding with offsets.

        Args:
            hidden_states (torch.Tensor): Hidden states of shape (batch_size, seq_len, embed_dim).
            input_indices (torch.Tensor): Indices of non-padded tokens in flattened form.
            offsets (torch.Tensor): Offsets specifying the start of each sequence.

        Returns:
            torch.Tensor: Pooled mean embeddings of shape (batch_size, embed_dim).
        """
        # Flatten hidden_states to 2D: shape (batch_size * seq_len, embed_dim)
        batch_size, seq_len, embed_dim = hidden_states.shape
        token_embeds = hidden_states.view(-1, embed_dim)
        # Use embedding_bag with mode 'mean' and the precomputed indices/offsets
        return F.embedding_bag(
            input=input_indices,  # Indices of non-padded tokens in flattened form
            weight=token_embeds,  # The flattened hidden states as embedding matrix
            offsets=offsets,  # Offsets specifying the start of each sequence
            mode="mean",  # Aggregation mode
        )


class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    base_model_prefix = "model"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = NLLBLLM2Vec(config)
        self.score = nn.Linear(
            config.llm2vec_config.hidden_size, self.num_labels, bias=False
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        if module is self.score:
            # INFO:
            # - critical that clf head is in float32 (NusaX perf. drops funky otherwise)
            # - Initialization needs to be redone, otherwise borked
            # - Use kaiming uniform, b/c Llama init (cf. `nn.Linear` below) performs worse
            self.score = self.score.to(torch.float32)
            torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.model.nllb_encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss); if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs.pooler_output
        pooled_logits = self.score(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(
                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPastAndPooler(
            loss=loss,
            hidden_states=hidden_states,
            logits=pooled_logits,
            pooler_output=transformer_outputs.pooler_output,
        )


class NLLBLLM2VecForTokenClassification(PreTrainedModel):
    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    base_model_prefix = "model"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: NLLBLLM2VecConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = NLLBLLM2Vec(config)
        self.classifier = nn.Linear(
            config.llm2vec_config.hidden_size, self.num_labels, bias=False
        )

        # Initialize weights and apply final processing
        self.post_init()

    def _init_weights(self, module):
        if module is self.classifier:
            # INFO:
            # - critical that clf head is in float32 (NusaX perf. drops funky otherwise)
            # - Initialization needs to be redone, otherwise borked
            # - Use kaiming uniform, b/c Llama init (cf. `nn.Linear` below) performs worse
            self.classifier = self.classifier.to(torch.float32)
            torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def get_input_embeddings(self):
        return self.model.nllb_encoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value

    # adapted from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification
    # - removed classifier dropout
    # - use F.cross_entropy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in
            `[0, ..., config.num_labels - 1]`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # move labels to the correct device to enable model parallelism
            labels = labels.to(logits.device)
            loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


AutoModel.register(NLLBLLM2VecConfig, NLLBLLM2Vec)
AutoModelForSequenceClassification.register(
    NLLBLLM2VecConfig, NLLBLLM2VecForSequenceClassification
)
AutoModelForTokenClassification.register(
    NLLBLLM2VecConfig, NLLBLLM2VecForTokenClassification
)
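
# Example usage (illustrative sketch; the checkpoint path is a placeholder and
# `trust_remote_code=True` is only needed when loading a hub checkpoint that ships this module):
#
#     model = AutoModel.from_pretrained("path/to/nllb-llm2vec-checkpoint", trust_remote_code=True)
#     embeddings = model.encode(
#         ["A first sentence.", "A second sentence."],
#         src_lang="eng_Latn",
#         dataloader_kwargs={"batch_size": 8},
#     )
#     # embeddings: tensor of shape (2, model.config.llm2vec_config.hidden_size)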