BeardedMonster committed on
Commit
99ff7f8
1 Parent(s): 1a8950c

Upload GPTJXForCausalLM

Browse files
Files changed (1) hide show
  1. pretrained_model.py +1 -0
pretrained_model.py CHANGED
@@ -58,6 +58,7 @@ class CausalSelfAttention(nn.Module):
58
  if self.flash:
59
  if attn_mask is not None:
60
  # efficient attention using Flash Attention CUDA kernels
 
61
  y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0)
62
  else:
63
  y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
 
58
  if self.flash:
59
  if attn_mask is not None:
60
  # efficient attention using Flash Attention CUDA kernels
61
+ attn_mask = attn_mask.to(torch.bool)
62
  y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0)
63
  else:
64
  y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)