ammarnasr committed
Commit ffa08ea
1 Parent(s): cc9ca31

Upload model

Files changed (1)
  1. modeling_t5mimo.py +4 -1
modeling_t5mimo.py CHANGED
@@ -1420,6 +1420,7 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        use_conv: Optional[bool] = True,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1517,6 +1518,9 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
 
         sequence_output = decoder_outputs[0]
 
+        if use_conv:
+            sequence_output = self.conv_block(sequence_output)
+
         # Set device for model parallelism
         if self.model_parallel:
             torch.cuda.set_device(self.encoder.first_device)
@@ -1528,7 +1532,6 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
         # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
         sequence_output = sequence_output * (self.model_dim**-0.5)
 
-        sequence_output = self.conv_block(sequence_output)
         lm_logits = self.lm_head(sequence_output)
 
         loss = None
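
In short, the commit adds a use_conv flag to forward and moves the self.conv_block(...) call from after the self.model_dim**-0.5 scaling to immediately after the decoder output, gated by that flag. Below is a minimal usage sketch of the new argument; the repo id, loading via trust_remote_code, and the dummy tensor shapes are assumptions for illustration, not taken from this commit.

# Usage sketch only: the repo id, the trust_remote_code loading path, and the
# dummy input shapes below are assumptions, not part of the commit itself.
import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "ammarnasr/t5mimo",  # placeholder repo id
    trust_remote_code=True,
)

vocab = model.config.vocab_size
input_ids = torch.randint(0, vocab, (1, 16))  # dummy encoder input; adjust to the model's expected shape
labels = torch.randint(0, vocab, (1, 8))      # dummy target sequence

# use_conv=True (the new default) routes the decoder output through
# self.conv_block before the LM head; use_conv=False skips that step.
with torch.no_grad():
    out_conv = model(input_ids=input_ids, labels=labels, use_conv=True)
    out_plain = model(input_ids=input_ids, labels=labels, use_conv=False)

print(out_conv.loss.item(), out_plain.loss.item())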