ammarnasr committed
Commit: 248a174
1 Parent(s): ffa08ea

Upload model

Files changed (3):
  1. config.json +1 -0
  2. configuration_t5mimo.py +2 -0
  3. modeling_t5mimo.py +86 -95
config.json CHANGED
@@ -18,6 +18,7 @@
   "initializer_factor": 0.05,
   "is_encoder_decoder": true,
   "is_gated_act": false,
+  "is_mimo": true,
   "layer_norm_epsilon": 1e-06,
   "model_type": "t5mimo",
   "num_decoder_layers": 4,
configuration_t5mimo.py CHANGED
@@ -81,6 +81,7 @@ class T5MIMOConfig(PretrainedConfig):
         classifier_dropout=0.0,
         num_seqs=3,
         num_filters=64,
+        is_mimo=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -102,6 +103,7 @@ class T5MIMOConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.num_seqs = num_seqs
         self.num_filters = num_filters
+        self.is_mimo = is_mimo
 
         act_info = self.feed_forward_proj.split("-")
         self.dense_act_fn = act_info[-1]
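With `is_mimo` now a regular constructor argument (defaulting to `True`), callers can switch the multivariate code paths on or off when building a config. A brief sketch, assuming the class is imported from the local `configuration_t5mimo.py`:

```python
from configuration_t5mimo import T5MIMOConfig

# Multivariate setup: the model expects an extra sequence axis,
# e.g. input_ids of shape (batch, num_seqs, seq_len).
mimo_cfg = T5MIMOConfig(num_seqs=3, is_mimo=True)

# Standard T5-style behaviour on (batch, seq_len) inputs.
flat_cfg = T5MIMOConfig(is_mimo=False)

print(mimo_cfg.is_mimo, flat_cfg.is_mimo)  # True False
```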
modeling_t5mimo.py CHANGED
@@ -198,8 +198,9 @@ class T5Attention(nn.Module):
         self.d_model = config.d_model
         self.key_value_proj_dim = config.d_kv
         self.n_heads = config.num_heads
-        self.dropout = config.dropout_rate
         self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.dropout = config.dropout_rate
+        self.config = config
 
         # Mesh TensorFlow initialization to avoid scaling before softmax
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -276,7 +277,7 @@
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets
 
-    def compute_bias(self, query_length, key_length, multivar_dim=-1, device=None):
+    def compute_bias(self, query_length, key_length, device=None):
         """Compute binned relative position bias"""
         if device is None:
             device = self.relative_attention_bias.weight.device
@@ -291,9 +292,8 @@
         )
         values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
-        if multivar_dim != -1:  # shape (1, multivar_dim, num_heads, query_length, key_length) (copy across)
-            values = values.expand(1, multivar_dim, -1, -1, -1)
-
+        if self.config.is_mimo:
+            values = values.unsqueeze(0)  # shape (1, 1, num_heads, query_length, key_length)
         return values
 
     def forward(
@@ -314,42 +314,41 @@
         # Input is (batch_size, seq_length, dim)
         # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
-        if len(hidden_states.shape) == 3:
-            batch_size, seq_length = hidden_states.shape[:2]
+        if self.config.is_mimo:
+            batch_size, multivar_dim, seq_length = hidden_states.shape[:3]
         else:
-            batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
-            multivar_dim = hidden_states.shape[1]
+            batch_size, seq_length = hidden_states.shape[:2]
         real_seq_length = seq_length
-
+
         if past_key_value is not None:
             if len(past_key_value) != 2:
-                raise ValueError(
-                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
-                )
-            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
+                raise ValueError(f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states")
+            if self.config.is_mimo:
+                real_seq_length += past_key_value[0].shape[3] if query_length is None else query_length
+            else:
+                real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
-        if len(hidden_states.shape) == 3:
-            key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
-        else:
+        if self.config.is_mimo:
             key_length = real_seq_length if key_value_states is None else key_value_states.shape[2]
+        else:
+            key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+
 
         def shape(states):
             """projection"""
-            # states: torch.Size([3, 16, 512]) -> query_states: torch.Size([3, 8, 16, 64])
-            # states: torch.Size([3, 6, 16, 512]) -> query_states: torch.Size([3, 6, 8, 16, 64])
-            if len(states.shape) == 3:
-                return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
-            else:
+            if self.config.is_mimo:
                 return states.view(batch_size, multivar_dim, -1, self.n_heads, self.key_value_proj_dim).transpose(2, 3)
+            else:
+                return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
 
         def unshape(states):
             """reshape"""
-            if len(states.shape) == 4:
-                return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
-            else:
+            if self.config.is_mimo:
                 return states.transpose(2, 3).contiguous().view(batch_size, multivar_dim, -1, self.inner_dim)
+            else:
+                return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
 
         def project(hidden_states, proj_layer, key_value_states, past_key_value):
             """projects hidden states correctly to key/query states"""
@@ -361,12 +360,14 @@
                 # cross-attn
                 # (batch_size, n_heads, seq_length, dim_per_head)
                 hidden_states = shape(proj_layer(key_value_states))
-
             if past_key_value is not None:
                 if key_value_states is None:
                     # self-attn
                     # (batch_size, n_heads, key_length, dim_per_head)
-                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                    if self.config.is_mimo:
+                        hidden_states = torch.cat([past_key_value, hidden_states], dim=3)
+                    else:
+                        hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                 elif past_key_value.shape[2] != key_value_states.shape[1]:
                     # checking that the `sequence_length` of the `past_key_value` is the same as
                     # the provided `key_value_states` to support prefix tuning
@@ -393,14 +394,10 @@
 
 
         # compute scores
-        if len(hidden_states.shape) == 3:
-            scores = torch.matmul(
-                query_states, key_states.transpose(3, 2)
-            )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+        if self.config.is_mimo:
+            scores = torch.matmul(query_states, key_states.transpose(4, 3))
         else:
-            scores = torch.matmul(
-                query_states, key_states.transpose(4, 3)
-            )
+            scores = torch.matmul(query_states, key_states.transpose(3, 2))  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
 
 
 
@@ -408,28 +405,22 @@
 
         if position_bias is None:
             if not self.has_relative_attention_bias:
-
-                if len(hidden_states.shape) == 3:
-                    position_bias = torch.zeros(
-                        (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
-                    )
+                if self.config.is_mimo:
+                    position_bias = torch.zeros((1, 1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
                 else:
-                    position_bias = torch.zeros(
-                        (1, multivar_dim, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
-                    )
+                    position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
                 if self.gradient_checkpointing and self.training:
                     position_bias.requires_grad = True
             else:
-
-                if len(hidden_states.shape) == 3:
-                    position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
-                else:
-                    position_bias = self.compute_bias(real_seq_length, key_length, multivar_dim=multivar_dim, device=scores.device)
+                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
 
             # if key and values are already calculated
             # we want only the last query position bias
             if past_key_value is not None:
-                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
+                if self.config.is_mimo:
+                    position_bias = position_bias[:, :, :, -hidden_states.size(2) :, :]
+                else:
+                    position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
 
             if mask is not None:
                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
@@ -443,24 +434,16 @@
         else:
             position_bias_masked = position_bias
 
-
         scores += position_bias_masked
-        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-            scores
-        )  # (batch_size, n_heads, seq_length, key_length)
-        attn_weights = nn.functional.dropout(
-            attn_weights, p=self.dropout, training=self.training
-        )  # (batch_size, n_heads, seq_length, key_length)
+        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # (batch_size, n_heads, seq_length, key_length)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)  # (batch_size, n_heads, seq_length, key_length)
 
         # Mask heads if we want to
         if layer_head_mask is not None:
             attn_weights = attn_weights * layer_head_mask
 
 
-        if len(hidden_states.shape) == 3:
-            attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
-        else:
-            attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, multivar_dim, seq_length, dim)
+        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
         attn_output = self.o(attn_output)
 
 
@@ -526,7 +509,6 @@ class T5LayerCrossAttention(nn.Module):
         query_length=None,
         output_attentions=False,
     ):
-
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.EncDecAttention(
             normed_hidden_states,
@@ -555,6 +537,8 @@ class T5Block(nn.Module):
 
         self.layer.append(T5LayerFF(config))
 
+        self.config = config
+
     def forward(
         self,
         hidden_states,
@@ -613,7 +597,10 @@ class T5Block(nn.Module):
             # the actual query length is unknown for cross attention
             # if using past key value states. Need to inject it here
             if present_key_value_state is not None:
-                query_length = present_key_value_state[0].shape[2]
+                if self.config.is_mimo:
+                    query_length = present_key_value_state[0].shape[3]
+                else:
+                    query_length = present_key_value_state[0].shape[2]
             else:
                 query_length = None
 
@@ -885,19 +872,14 @@ class T5Stack(T5PreTrainedModel):
             self.embed_tokens = self.embed_tokens.to(self.first_device)
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if input_ids is not None and inputs_embeds is not None:
             err_msg_prefix = "decoder_" if self.is_decoder else ""
-            raise ValueError(
-                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
-            )
+            raise ValueError(f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
-            # input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
         else:
@@ -909,13 +891,16 @@
                 raise ValueError("You have to initialize the model with valid token embeddings")
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if len(input_shape) == 3:
+        if self.config.is_mimo:
             batch_size, multivar_seqs, seq_length = input_shape
         else:
             batch_size, seq_length = input_shape
 
         # required mask seq length can be calculated via length of past
-        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+        if self.config.is_mimo:
+            mask_seq_length = past_key_values[0][0].shape[3] + seq_length if past_key_values is not None else seq_length
+        else:
+            mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
         if use_cache is True:
             if not self.is_decoder:
@@ -926,45 +911,34 @@
             past_key_values = [None] * len(self.block)
 
         if attention_mask is None:
-            if len(input_shape) == 2:
-                attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
-            else:
-                attention_mask = torch.ones(batch_size, multivar_seqs, mask_seq_length, device=inputs_embeds.device)
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
 
 
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        if len(input_shape) == 2:
-            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+        if self.config.is_mimo:
+            extended_attention_mask = self.get_extended_attention_mask(attention_mask, (input_shape[0], input_shape[2]))
+            extended_attention_mask = extended_attention_mask.unsqueeze(1)
         else:
             extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
-            # permute from [batch_size, 1, multivar_seqs, seq_length] to [batch_size, multivar_seqs, 1, seq_length]
-            extended_attention_mask = extended_attention_mask.permute(0, 2, 1, 3)
-            # Now make it [batch_size, multivar_seqs, 1, 1, seq_length]
-            extended_attention_mask = extended_attention_mask.unsqueeze(3)
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.is_decoder and encoder_hidden_states is not None:
-
-            if len(encoder_hidden_states.size()) == 3:
-                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
-            else:
+            if self.config.is_mimo:
                 encoder_batch_size, multivar_dem, encoder_sequence_length, _ = encoder_hidden_states.size()
+            else:
+                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
 
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(
-                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
-                )
-            if len(input_shape) == 2:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)
+            if self.config.is_mimo:
                 encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+                encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(1)
             else:
                 encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-                multivar_dim = extended_attention_mask.shape[1]
-                encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(1)
-                encoder_extended_attention_mask = encoder_extended_attention_mask.permute(0, 3, 1, 2, 4)
 
         else:
             encoder_extended_attention_mask = None
@@ -973,9 +947,7 @@
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
+                logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                 use_cache = False
 
         # Prepare head mask if needed
@@ -1453,6 +1425,8 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
+
+
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -1461,6 +1435,8 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
            if self.config.num_layers == self.config.num_decoder_layers:
                decoder_head_mask = head_mask
 
+
+
        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
@@ -1500,6 +1476,15 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
 
+        if hidden_states is not None and decoder_input_ids is not None:
+            if len(hidden_states.shape) == 4:
+                batch_size, multivar_seqs, seq_length, model_dim = hidden_states.shape
+                if len(decoder_input_ids.shape) == 2:
+                    decoder_input_ids = decoder_input_ids.unsqueeze(1).repeat(1, multivar_seqs, 1)
+
+
+
+
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
@@ -1518,6 +1503,7 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
 
        sequence_output = decoder_outputs[0]
 
+
        if use_conv:
            sequence_output = self.conv_block(sequence_output)
 
@@ -1548,8 +1534,11 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output
+
+
+
 
-        return Seq2SeqLMOutput(
+        seq2seqlmoutput = Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
@@ -1560,6 +1549,7 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
+        return seq2seqlmoutput
 
    def prepare_inputs_for_generation(
        self,
@@ -1640,6 +1630,7 @@ class T5MIMOEncoderModel(T5PreTrainedModel):
 
    def __init__(self, config: T5MIMOConfig):
        super().__init__(config)
+
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
 
        encoder_config = copy.deepcopy(config)
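Taken together, the modeling changes replace the old shape-sniffing (`len(hidden_states.shape) == 3`) with explicit `self.config.is_mimo` branches: in MIMO mode every tensor carries a multivariate axis right after the batch axis, attention runs on `(batch, num_seqs, n_heads, seq_len, d_kv)`, the cached key/value length is read from `shape[3]` instead of `shape[2]`, the relative position bias gets a leading singleton axis that broadcasts over `num_seqs` (replacing the old explicit `expand`), and 2-D `decoder_input_ids` are repeated across the multivariate axis before decoding. A standalone sketch of that bookkeeping (dimensions are illustrative, not taken from the checkpoint):

```python
import torch

batch, num_seqs, seq_len, d_model = 2, 3, 16, 512
n_heads, d_kv = 8, 64

hidden = torch.randn(batch, num_seqs, seq_len, d_model)

# shape(): (batch, num_seqs, seq_len, n_heads*d_kv) -> (batch, num_seqs, n_heads, seq_len, d_kv)
states = hidden.view(batch, num_seqs, -1, n_heads, d_kv).transpose(2, 3)
print(states.shape)   # torch.Size([2, 3, 8, 16, 64])

# scores: matmul over the last two axes, mirroring key_states.transpose(4, 3)
scores = torch.matmul(states, states.transpose(4, 3))
print(scores.shape)   # torch.Size([2, 3, 8, 16, 16])

# unshape(): back to (batch, num_seqs, seq_len, inner_dim)
out = states.transpose(2, 3).contiguous().view(batch, num_seqs, -1, n_heads * d_kv)
print(out.shape)      # torch.Size([2, 3, 16, 512])

# Generation path: broadcast 2-D decoder_input_ids across the multivariate axis,
# as the new block in T5MIMOForConditionalGeneration.forward does.
decoder_input_ids = torch.zeros(batch, 4, dtype=torch.long)
decoder_input_ids = decoder_input_ids.unsqueeze(1).repeat(1, num_seqs, 1)
print(decoder_input_ids.shape)  # torch.Size([2, 3, 4])
```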