max model size / max seq length

#15
by paulmoonraker - opened

Hi,

I see that 'By default, input text longer than 384 word pieces is truncated'. However, in the tokenizer config I see that model_max_length is 512. Does the model respect the 384 limit, or do I need to set the max seq length somewhere? Thanks,

Sentence Transformers org

Hello!

The model indeed respects the token length of 384 via this configuration setting: https://huggingface.co/sentence-transformers/all-mpnet-base-v2/blob/main/sentence_bert_config.json#L2
This parameter has priority over the tokenizer one. I do recognize that it is a bit confusing to have two separate values for the same setting in the model.

  • Tom Aarsen
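
To make the precedence described above concrete, here is a minimal sketch in plain Python. The helper `effective_max_seq_length` is hypothetical and written only for illustration; it is not the library's code. It assumes that a `max_seq_length` in sentence_bert_config.json wins when present, and that otherwise the tokenizer's `model_max_length` is used (optionally capped by the model's position limit, which is also an assumption here).

```python
from typing import Optional


def effective_max_seq_length(
    sentence_bert_config: dict,
    tokenizer_model_max_length: int,
    max_position_embeddings: Optional[int] = None,
) -> int:
    """Hypothetical helper: pick the sequence length that would apply.

    Assumption: an explicit max_seq_length in sentence_bert_config.json
    takes priority over the tokenizer's model_max_length.
    """
    configured = sentence_bert_config.get("max_seq_length")
    if configured is not None:
        return configured
    # Assumed fallback: use the tokenizer value, capped by the model limit if known.
    if max_position_embeddings is not None:
        return min(tokenizer_model_max_length, max_position_embeddings)
    return tokenizer_model_max_length


# Values quoted in this thread for all-mpnet-base-v2.
print(effective_max_seq_length({"max_seq_length": 384}, 512))  # 384
# Without the Sentence Transformers setting, the tokenizer value would apply.
print(effective_max_seq_length({}, 512))  # 512
```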

Thank you for the response.

Is it possible to set the max_seq_length to 512 via transformers?

Sentence Transformers org

You could, but the performance of the model will likely be worse than if you kept it at 384. Feel free to experiment with it:

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling

transformer = Transformer("sentence-transformers/all-mpnet-base-v2", max_seq_length=512)
pooling = Pooling(transformer.get_word_embedding_dimension(), "mean")
model = SentenceTransformer(modules=[transformer, pooling])

embedding = model.encode("My text!")
print(embedding.shape)

Hi, is it possible to specify max_seq_length if we are using AutoTokenizer and AutoModel? I can pass max_length at tokenisation time, but I doubt that stops the model truncating at 384. Thank you

from transformers import AutoTokenizer, AutoModel
import torch

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

#Sentences we want sentence embeddings for
sentences = ['This framework generates embeddings for each input sentence']

#Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

#Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')

#Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)

#Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

Some simple testing indicates that the model processes up to 512 tokens. For the same text string, if I truncate at 384 and then again at a length greater than 384 but less than 512, I get different vectors back. If I try a length greater than 512, the model throws an error.
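
One way to see why the cutoff changes the result: with the mean_pooling function quoted above, every token that survives truncation contributes to the average. The toy sketch below uses made-up 2-dimensional "token embeddings" and a plain-Python analogue of that pooling step. It only illustrates the pooling part; in a real transformer the extra tokens also change the other tokens' embeddings through attention.

```python
def masked_mean(token_vectors, attention_mask):
    """Average the token vectors whose mask entry is 1 (plain-Python mean pooling)."""
    dim = len(token_vectors[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_vectors, attention_mask):
        if keep:
            count += 1
            for i, value in enumerate(vec):
                total[i] += value
    count = max(count, 1e-9)  # same division-by-zero guard as torch.clamp(min=1e-9)
    return [t / count for t in total]


def truncate(token_vectors, max_length):
    """Keep only the first max_length tokens, as tokenizer truncation would."""
    return token_vectors[:max_length]


# Made-up embeddings for a 6-token input; the last two tokens differ from the rest.
tokens = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 2.0], [0.0, 2.0]]

short = truncate(tokens, 4)  # stands in for the shorter cutoff
full = truncate(tokens, 6)   # stands in for the longer cutoff

print(masked_mean(short, [1] * len(short)))  # [1.0, 0.0]
print(masked_mean(full, [1] * len(full)))    # a different vector
```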

I saw someone mention that the default max_seq_length should be 384, but for me it is now 512. I did nothing except load SentenceTransformer("all-mpnet-base-v2").

Could you explain why the default changed from 384 to 512?

SentenceTransformer("all-mpnet-base-v2")
output:
Transformer({'max_seq_length': 512, })

Sentence Transformers org

@paulmoonraker The model indeed crashes after 512 tokens, but it was trained to work with up to 384, so 384 is the recommended sequence length. You can set max_length at tokenization time like you've done, or, I believe, with AutoTokenizer.from_pretrained("...", model_max_length=384).
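
A likely reason for the hard failure past 512, assuming this model uses a fixed-size learned position-embedding table as BERT-style encoders typically do, is that there is simply no row for later positions. The toy `PositionTable` class below is hypothetical and only mimics that lookup: lengths up to the table size work (whether or not they were seen in training), and anything longer raises an error.

```python
class PositionTable:
    """Toy stand-in for a learned position-embedding table with a fixed number of rows."""

    def __init__(self, max_positions):
        # One (made-up) 1-dimensional embedding per position.
        self.rows = [[float(p)] for p in range(max_positions)]

    def lookup(self, seq_len):
        """Return one embedding per position; fails if seq_len exceeds the table size."""
        return [self.rows[p] for p in range(seq_len)]


table = PositionTable(512)  # 512 is the hard limit reported in this thread

print(len(table.lookup(384)))  # 384: the recommended length
print(len(table.lookup(512)))  # 512: still works, but beyond the recommended length

try:
    table.lookup(513)
except IndexError as exc:
    print("sequence too long:", exc)
```

Truncating to 384 at tokenization time keeps inputs inside both the hard limit and the recommended range.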

@keyuchen2020 Huh, that is odd. It should indeed be 384. What version of sentence-transformers are you using? With 2.5.1 I get:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
print(model)

SentenceTransformer(
  (0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)

i.e. max_seq_length of 384.

  • Tom Aarsen
