Upload model
Browse files- modeling_t5mimo.py +4 -1
modeling_t5mimo.py
CHANGED
@@ -1420,6 +1420,7 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
|
|
1420 |
output_attentions: Optional[bool] = None,
|
1421 |
output_hidden_states: Optional[bool] = None,
|
1422 |
return_dict: Optional[bool] = None,
|
|
|
1423 |
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
1424 |
r"""
|
1425 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
@@ -1517,6 +1518,9 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
|
|
1517 |
|
1518 |
sequence_output = decoder_outputs[0]
|
1519 |
|
|
|
|
|
|
|
1520 |
# Set device for model parallelism
|
1521 |
if self.model_parallel:
|
1522 |
torch.cuda.set_device(self.encoder.first_device)
|
@@ -1528,7 +1532,6 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
|
|
1528 |
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
1529 |
sequence_output = sequence_output * (self.model_dim**-0.5)
|
1530 |
|
1531 |
-
sequence_output = self.conv_block(sequence_output)
|
1532 |
lm_logits = self.lm_head(sequence_output)
|
1533 |
|
1534 |
loss = None
|
|
|
1420 |
output_attentions: Optional[bool] = None,
|
1421 |
output_hidden_states: Optional[bool] = None,
|
1422 |
return_dict: Optional[bool] = None,
|
1423 |
+
use_conv: Optional[bool] = True,
|
1424 |
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
1425 |
r"""
|
1426 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
|
1518 |
|
1519 |
sequence_output = decoder_outputs[0]
|
1520 |
|
1521 |
+
if use_conv:
|
1522 |
+
sequence_output = self.conv_block(sequence_output)
|
1523 |
+
|
1524 |
# Set device for model parallelism
|
1525 |
if self.model_parallel:
|
1526 |
torch.cuda.set_device(self.encoder.first_device)
|
|
|
1532 |
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
1533 |
sequence_output = sequence_output * (self.model_dim**-0.5)
|
1534 |
|
|
|
1535 |
lm_logits = self.lm_head(sequence_output)
|
1536 |
|
1537 |
loss = None
|