BeardedMonster committed
Commit 99ff7f8
Parent(s): 1a8950c
Upload GPTJXForCausalLM
pretrained_model.py CHANGED (+1 -0)
@@ -58,6 +58,7 @@ class CausalSelfAttention(nn.Module):
         if self.flash:
             if attn_mask is not None:
                 # efficient attention using Flash Attention CUDA kernels
+                attn_mask = attn_mask.to(torch.bool)
                 y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0)
             else:
                 y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
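The one-line change casts attn_mask to bool before passing it to scaled_dot_product_attention. In PyTorch, a boolean mask means "True = attend, False = mask out", whereas a floating-point mask is added to the attention scores; so a 0/1 float mask would merely nudge the logits instead of masking them, which is presumably what this commit fixes. A minimal sketch of the distinction (the shapes and the 0/1 padding-mask layout are illustrative assumptions, not taken from the repo):

import torch
import torch.nn.functional as F

# Hypothetical shapes: batch=2, heads=4, seq=8, head_dim=16
q = k = v = torch.randn(2, 4, 8, 16)

# A 0/1 padding mask (1 = attend, 0 = ignore), broadcastable to (B, H, T, T)
attn_mask = torch.ones(2, 1, 1, 8)
attn_mask[..., -2:] = 0  # pretend the last two tokens are padding

# As a float tensor, SDPA *adds* the mask to the attention scores,
# so 0/1 values would only shift the logits by 1 instead of masking.
# Cast to bool so True/False are interpreted as attend/ignore.
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask.to(torch.bool))
print(y.shape)  # torch.Size([2, 4, 8, 16])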