Setting num_return_sequences results in a shape mismatch error

#28
by Watarungurunnn - opened

hf_args:
do_sample: true
temperature: 0.8
top_k: 50
top_p: 0.95
num_return_sequences: 30
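
For context, a minimal sketch of how a config block like this might be loaded and handed to generate, assuming PyYAML and a hypothetical config.yaml containing the block above (neither is named in the original report):

import yaml

# "config.yaml" is a hypothetical file holding the hf_args block above.
with open("config.yaml") as f:
    hf_args = yaml.safe_load(f)["hf_args"]
# hf_args is now a plain dict, e.g. {"do_sample": True, "temperature": 0.8,
# "top_k": 50, "top_p": 0.95, "num_return_sequences": 30}, and is unpacked
# into model.generate(**hf_args) below.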

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
    **model_args,
)
generated_tokens = model.generate(
    inputs=input_ids,
    pad_token_id=tokenizer.pad_token_id,
    **hf_args,  # sampling settings above, including num_return_sequences=30
)
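
For reference, once generation succeeds, generate with a single prompt and num_return_sequences=30 returns batch_size * num_return_sequences sequences stacked along the batch dimension:

# Expected output shape (standard generate behavior, not from the
# original report): one row per returned sequence.
assert generated_tokens.shape[0] == input_ids.shape[0] * hf_args["num_return_sequences"]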

Error:

  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 1068, in forward
    outputs = self.model(
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 908, in forward
    layer_outputs = decoder_layer(
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 650, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 252, in forward
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/cache_utils.py", line 1227, in update
    return update_fn(
  File "/home/user_2942/.local/lib/python3.10/site-packages/transformers/cache_utils.py", line 1202, in _static_update
    k_out[:, :, cache_position] = key_states
RuntimeError: shape mismatch: value tensor of shape [30, 16, 942, 128] cannot be broadcast to indexing result of shape [1, 16, 942, 128]
time="2024-07-09T04:45:52 UTC" level=info msg="sub-process exited" argo=true error="<nil>"
Error: exit status 1

Hi,
The mismatch happens because the static cache is allocated for the original batch size (1 here), while generate expands the inputs to batch_size * num_return_sequences (30) rows before decoding. Changing line 1767 in generation/utils.py to getattr(generation_config, "num_beams", 1) * getattr(generation_config, "num_return_sequences", 1) * batch_size fixed the problem for me. Hope you find that helpful :)
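
A self-contained sketch of what that change computes; the names old_cache_batch and new_cache_batch are illustrative, and the "original" expression is a presumption based on the error, not quoted from the transformers source:

from transformers import GenerationConfig

# Stand-ins for the objects generate builds internally.
generation_config = GenerationConfig(num_return_sequences=30, num_beams=1)
batch_size = 1  # one prompt in the report

# Presumed original sizing: omits num_return_sequences, so the static
# cache is allocated for batch 1.
old_cache_batch = getattr(generation_config, "num_beams", 1) * batch_size

# The poster's fix: size the cache for the expanded batch, matching the
# [30, 16, 942, 128] key states in the traceback.
new_cache_batch = (
    getattr(generation_config, "num_beams", 1)
    * getattr(generation_config, "num_return_sequences", 1)
    * batch_size
)
print(old_cache_batch, new_cache_batch)  # 1 30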

Any update on this issue? I cannot use a fork of transformers in my project.

It was fixed and released already; just make sure to update transformers 😄
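
A quick way to check what you have installed, since the thread does not name the exact release that shipped the fix:

import transformers

# Upgrade first if needed: pip install -U transformers
print(transformers.__version__)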

Google org

Hi @Watarungurunnn, could you please confirm whether you are still facing the issue after updating transformers? Let us know and we will assist you; otherwise we can close this issue.

Thank you.

Seems fixed! Thank you

Watarungurunnn changed discussion status to closed
