Using item() makes torch.export not possible
#61, opened by bartel97
Hi!
In the classes `Phi3SuScaledRotaryEmbedding` and `Phi3YarnScaledRotaryEmbedding` we see this kind of pattern:
```python
def forward(self, x, position_ids, seq_len=None):
    seq_len = torch.max(position_ids) + 1
    if seq_len > self.original_max_position_embeddings:
        ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
    else:
        ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
```
`torch.export.export()` breaks at `if seq_len > self.original_max_position_embeddings:` because `seq_len` is data dependent and we try to branch on it. Two questions:
- Why are we recomputing `seq_len` if we pass it to the function? In all call sites it is always passed, so it seems to me the parameter could be made required and the recomputation dropped. I don't think this fixes the following problem in the flash attention path though, where we branch on a data dependency again:

  ```python
  rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
  cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
  ```
- Is there no way of getting the sequence length without calling `max()` on the tensor? If we do need it, is there a way to branch on something that is not data dependent? (One idea is sketched right after this list.)
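One possible direction for the second question, just a sketch of the idea and not a tested patch against the actual modeling code: keep the selection at the tensor level with `torch.where`, so there is no Python branch on a data-dependent value. The names `long_factor` / `short_factor` / `original_max_position_embeddings` follow the snippet above; making them tensors and the helper signature are my assumptions.

```python
import torch

def select_ext_factors(position_ids, long_factor, short_factor, original_max_position_embeddings):
    # Sketch only: both factor tensors are available and torch.where picks one,
    # so the choice stays inside the exported graph instead of being a Python
    # `if` that needs a concrete bool at trace time.
    seq_len = torch.max(position_ids) + 1
    return torch.where(
        seq_len > original_max_position_embeddings,
        long_factor,
        short_factor,
    )
```

The tradeoff is that both factor tensors have to be materialized on every call; registering them once as buffers in `__init__` would avoid that, and would also address the recomputation concern from the first bullet.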
For reference, I am following the torch.export tutorial at https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html and seeing the error described in https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit#heading=h.r02234kuof4f. That document also describes strategies that might help here.
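In case it helps triage, here is a minimal standalone sketch (a toy module with made-up sizes, not the actual Phi3 class) that should hit the same data-dependent guard error under `torch.export`:

```python
import torch
import torch.nn as nn

class ToyScaledRotary(nn.Module):
    """Toy stand-in for the branch in Phi3SuScaledRotaryEmbedding (illustration only)."""
    def __init__(self):
        super().__init__()
        self.original_max_position_embeddings = 4096  # made-up value
        self.long_factor = [1.0] * 48                 # made-up rescale factors
        self.short_factor = [1.0] * 48

    def forward(self, x, position_ids):
        seq_len = torch.max(position_ids) + 1  # depends on tensor *data*, not shape
        if seq_len > self.original_max_position_embeddings:  # Python branch on that data
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
        return x[..., None] * ext_factors

x = torch.randn(1, 8)
position_ids = torch.arange(8).unsqueeze(0)

# Export trips over the `if`, since tracing cannot turn the data-dependent
# comparison into a plain Python bool (GuardOnDataDependentSymNode).
torch.export.export(ToyScaledRotary(), (x, position_ids))
```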
nguyenbh changed discussion status to closed