The current checkpoint doesn't use group query attention.

#3
by yaya-sy - opened

When I tried to load the model using:

import torch
from transformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained("lelapa/InkubaLM-0.4B",
                                           torch_dtype=torch.float16)

I encountered the following error:

RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
    size mismatch for model.layers.0.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.0.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.1.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.1.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.2.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.2.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.3.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.3.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.4.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.4.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.5.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.5.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.6.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.6.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.7.self_attn.k_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    size mismatch for model.layers.7.self_attn.v_proj.weight: copying a param with shape torch.Size([2048, 2048]) from checkpoint, the shape in current model is torch.Size([256, 2048]).
    You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
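The shapes in this error follow from how Llama-style attention sizes its key/value projections. Below is a minimal sketch that reproduces them. It assumes the usual Llama convention head_dim = hidden_size // num_attention_heads, a hidden size of 2048 (the second dimension of the weights above) and 32 attention heads. The helper names kv_proj_shape and infer_num_kv_heads are hypothetical.

```python
from typing import Tuple


def kv_proj_shape(hidden_size: int, num_attention_heads: int,
                  num_key_value_heads: int) -> Tuple[int, int]:
    """Weight shape (out_features, in_features) of k_proj/v_proj in a
    Llama-style layer, assuming head_dim = hidden_size // num_attention_heads."""
    head_dim = hidden_size // num_attention_heads
    return (num_key_value_heads * head_dim, hidden_size)


def infer_num_kv_heads(weight_shape: Tuple[int, int], hidden_size: int,
                       num_attention_heads: int) -> int:
    """Recover num_key_value_heads from a checkpoint's k_proj/v_proj shape."""
    out_features, in_features = weight_shape
    if in_features != hidden_size:
        raise ValueError("in_features does not match hidden_size")
    head_dim = hidden_size // num_attention_heads
    if out_features % head_dim != 0:
        raise ValueError("out_features is not a multiple of head_dim")
    return out_features // head_dim


hidden, heads = 2048, 32
# GQA with 4 KV heads gives the shape the config asks for.
print(kv_proj_shape(hidden, heads, 4))    # (256, 2048)
# Plain MHA (one KV head per query head) gives the checkpoint's shape.
print(kv_proj_shape(hidden, heads, 32))   # (2048, 2048)
print(infer_num_kv_heads((2048, 2048), hidden, heads))  # 32
```

Under these assumptions, the config's [256, 2048] projections imply 4 key/value heads. The checkpoint's [2048, 2048] projections imply 32.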

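This error suggests that the current checkpoint uses standard multi-head attention (MHA) rather than grouped-query attention (GQA). The checkpoint's k_proj and v_proj weights are square ([2048, 2048]), whereas the config expects [256, 2048]. With 32 attention heads over a hidden size of 2048, each head has 64 dimensions. The config's 256-row projections therefore correspond to 4 key/value heads. The checkpoint's 2048-row projections correspond to one key/value head per query head (32). To fix this, I modified config.json to set num_key_value_heads = num_attention_heads = 32.
This pull request contains that change.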
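The config change described above can be sketched as a small script that edits a local copy of config.json. This is not the exact diff in this pull request. It assumes a standard Hugging Face config containing num_attention_heads and num_key_value_heads keys. The function name use_full_multi_head_attention is hypothetical, and the demo file below is an illustrative stand-in, not the real InkubaLM config.

```python
import json
import tempfile
from pathlib import Path


def use_full_multi_head_attention(config_path):
    """Set num_key_value_heads equal to num_attention_heads so that
    k_proj/v_proj are sized for standard multi-head attention.
    Returns (old_value, new_value) of num_key_value_heads."""
    path = Path(config_path)
    config = json.loads(path.read_text())
    old = config.get("num_key_value_heads")
    config["num_key_value_heads"] = config["num_attention_heads"]
    # Rewrite the file; all other keys are preserved unchanged.
    path.write_text(json.dumps(config, indent=2) + "\n")
    return old, config["num_key_value_heads"]


# Illustrative only: a minimal stand-in for config.json, using the head
# counts inferred from the error above (32 query heads, 4 KV heads).
with tempfile.TemporaryDirectory() as tmp:
    cfg = Path(tmp) / "config.json"
    cfg.write_text(json.dumps({"num_attention_heads": 32,
                               "num_key_value_heads": 4}))
    print(use_full_multi_head_attention(cfg))  # (4, 32)
```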

yaya-sy changed the pull request title from "The actual checkpoint doesn't use group query attention." to "The current checkpoint doesn't use group query attention."
Lelapa AI org

Hello, when loading the model, add trust_remote_code=True. For example:

llm = AutoModelForCausalLM.from_pretrained("lelapa/InkubaLM-0.4B", torch_dtype=torch.float16, trust_remote_code=True)
Atnafu changed pull request status to merged
