Use input attention mask instead of causal mask in attention
#74
by
CyberZHG
- opened
- modelling_RW.py +2 -2
modelling_RW.py
CHANGED
@@ -271,13 +271,14 @@ class Attention(nn.Module):
|
|
271 |
else:
|
272 |
present = None
|
273 |
|
|
|
274 |
if alibi is None:
|
275 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
276 |
key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
277 |
value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
278 |
|
279 |
attn_output = F.scaled_dot_product_attention(
|
280 |
-
query_layer_, key_layer_, value_layer_,
|
281 |
)
|
282 |
|
283 |
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
@@ -290,7 +291,6 @@ class Attention(nn.Module):
|
|
290 |
assert not output_attentions # not supported.
|
291 |
return outputs
|
292 |
else:
|
293 |
-
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
|
294 |
matmul_result = query_layer @ key_layer.transpose(-1, -2)
|
295 |
|
296 |
# change view to [batch_size, num_heads, q_length, kv_length]
|
|
|
271 |
else:
|
272 |
present = None
|
273 |
|
274 |
+
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(query_layer.dtype)
|
275 |
if alibi is None:
|
276 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
277 |
key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
278 |
value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
|
279 |
|
280 |
attn_output = F.scaled_dot_product_attention(
|
281 |
+
query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
|
282 |
)
|
283 |
|
284 |
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
|
|
291 |
assert not output_attentions # not supported.
|
292 |
return outputs
|
293 |
else:
|
|
|
294 |
matmul_result = query_layer @ key_layer.transpose(-1, -2)
|
295 |
|
296 |
# change view to [batch_size, num_heads, q_length, kv_length]
|