Upload model

- config.json: +1 -0
- configuration_t5mimo.py: +2 -0
- modeling_t5mimo.py: +86 -95
config.json CHANGED
@@ -18,6 +18,7 @@
   "initializer_factor": 0.05,
   "is_encoder_decoder": true,
   "is_gated_act": false,
+  "is_mimo": true,
   "layer_norm_epsilon": 1e-06,
   "model_type": "t5mimo",
   "num_decoder_layers": 4,
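
The flag added above is a plain config entry, so it surfaces directly on the loaded config object. A minimal sketch, assuming the repo wires up its custom classes for trust_remote_code; the repo id below is a placeholder, not part of this commit:

from transformers import AutoConfig

# Placeholder repo id; trust_remote_code=True lets the Hub load the custom
# configuration_t5mimo.py / modeling_t5mimo.py files shipped with the model.
config = AutoConfig.from_pretrained("your-org/t5mimo", trust_remote_code=True)

print(config.model_type)  # "t5mimo"
print(config.is_mimo)     # True, from the new "is_mimo" entry in config.json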
configuration_t5mimo.py CHANGED
@@ -81,6 +81,7 @@ class T5MIMOConfig(PretrainedConfig):
         classifier_dropout=0.0,
         num_seqs=3,
         num_filters=64,
+        is_mimo=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -102,6 +103,7 @@ class T5MIMOConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.num_seqs = num_seqs
         self.num_filters = num_filters
+        self.is_mimo = is_mimo
 
         act_info = self.feed_forward_proj.split("-")
         self.dense_act_fn = act_info[-1]
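
For reference, a minimal sketch of constructing the config with the new argument. It assumes configuration_t5mimo.py is importable locally and relies only on parameters visible in this diff (num_seqs, num_filters, is_mimo); everything else keeps its default:

# Assumes configuration_t5mimo.py from this repo is on the import path.
from configuration_t5mimo import T5MIMOConfig

cfg = T5MIMOConfig(num_seqs=3, num_filters=64, is_mimo=True)
assert cfg.is_mimo is True

# Turning the flag off selects the standard single-sequence code paths
# in modeling_t5mimo.py (plain (batch_size, seq_length, d_model) tensors).
cfg_single = T5MIMOConfig(is_mimo=False)
assert cfg_single.is_mimo is False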
modeling_t5mimo.py CHANGED
@@ -198,8 +198,9 @@ class T5Attention(nn.Module):
         self.d_model = config.d_model
         self.key_value_proj_dim = config.d_kv
         self.n_heads = config.num_heads
-        self.dropout = config.dropout_rate
         self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.dropout = config.dropout_rate
+        self.config = config
 
         # Mesh TensorFlow initialization to avoid scaling before softmax
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -276,7 +277,7 @@ class T5Attention(nn.Module):
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets
 
-    def compute_bias(self, query_length, key_length, ...):
+    def compute_bias(self, query_length, key_length, device=None):
         """Compute binned relative position bias"""
         if device is None:
             device = self.relative_attention_bias.weight.device
@@ -291,9 +292,8 @@ class T5Attention(nn.Module):
         )
         values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
-        if ...
-            values = values. ...
-
+        if self.config.is_mimo:
+            values = values.unsqueeze(0)  # shape (1, 1, num_heads, query_length, key_length)
         return values
 
     def forward(
@@ -314,42 +314,41 @@ class T5Attention(nn.Module):
         # Input is (batch_size, seq_length, dim)
         # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
-        if ...
-            batch_size, seq_length = hidden_states.shape[:2]
+        if self.config.is_mimo:
+            batch_size, multivar_dim, seq_length = hidden_states.shape[:3]
         else:
-            batch_size, seq_length = hidden_states.shape[ ...
-            multivar_dim = hidden_states.shape[1]
+            batch_size, seq_length = hidden_states.shape[:2]
         real_seq_length = seq_length
 
         if past_key_value is not None:
             if len(past_key_value) != 2:
-                raise ValueError(
-                    ...
-                    ...
-                    ...
+                raise ValueError(f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states")
+            if self.config.is_mimo:
+                real_seq_length += past_key_value[0].shape[3] if query_length is None else query_length
+            else:
+                real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
-        if ...
-            key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
-        else:
+        if self.config.is_mimo:
             key_length = real_seq_length if key_value_states is None else key_value_states.shape[2]
+        else:
+            key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+
 
 
         def shape(states):
             """projection"""
-
-            # states: torch.Size([3, 6, 16, 512]) -> query_states: torch.Size([3, 6, 8 , 16, 64])
-            if len(states.shape) == 3:
-                return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
-            else:
+            if self.config.is_mimo:
                 return states.view(batch_size, multivar_dim, -1, self.n_heads, self.key_value_proj_dim).transpose(2, 3)
+            else:
+                return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
 
 
         def unshape(states):
             """reshape"""
-            if ...
-                return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
-            else:
+            if self.config.is_mimo:
                 return states.transpose(2, 3).contiguous().view(batch_size, multivar_dim, -1, self.inner_dim)
+            else:
+                return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
 
         def project(hidden_states, proj_layer, key_value_states, past_key_value):
             """projects hidden states correctly to key/query states"""
@@ -361,12 +360,14 @@ class T5Attention(nn.Module):
                 # cross-attn
                 # (batch_size, n_heads, seq_length, dim_per_head)
                 hidden_states = shape(proj_layer(key_value_states))
-
             if past_key_value is not None:
                 if key_value_states is None:
                     # self-attn
                     # (batch_size, n_heads, key_length, dim_per_head)
-                    ...
+                    if self.config.is_mimo:
+                        hidden_states = torch.cat([past_key_value, hidden_states], dim=3)
+                    else:
+                        hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                 elif past_key_value.shape[2] != key_value_states.shape[1]:
                     # checking that the `sequence_length` of the `past_key_value` is the same as
                     # the provided `key_value_states` to support prefix tuning
@@ -393,14 +394,10 @@ class T5Attention(nn.Module):
 
 
         # compute scores
-        if ...
-            scores = torch.matmul(
-                query_states, key_states.transpose(3, 2)
-            )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+        if self.config.is_mimo:
+            scores = torch.matmul(query_states, key_states.transpose(4, 3))
         else:
-            scores = torch.matmul(
-                query_states, key_states.transpose(4, 3)
-            )
+            scores = torch.matmul(query_states, key_states.transpose(3, 2))  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
 
 
 
@@ -408,28 +405,22 @@ class T5Attention(nn.Module):
 
         if position_bias is None:
             if not self.has_relative_attention_bias:
-                ...
-                ...
-                    position_bias = torch.zeros(
-                        (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
-                    )
+                if self.config.is_mimo:
+                    position_bias = torch.zeros((1,1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
                 else:
-                    position_bias = torch.zeros(
-                        (1,multivar_dim, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
-                    )
+                    position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
                 if self.gradient_checkpointing and self.training:
                     position_bias.requires_grad = True
             else:
-                ...
-                if len(hidden_states.shape) == 3:
-                    position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
-                else:
-                    position_bias = self.compute_bias(real_seq_length, key_length,multivar_dim=multivar_dim, device=scores.device)
+                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
 
             # if key and values are already calculated
             # we want only the last query position bias
             if past_key_value is not None:
-                ...
+                if self.config.is_mimo:
+                    position_bias = position_bias[:, :, :, -hidden_states.size(2) :, :]
+                else:
+                    position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
 
         if mask is not None:
             position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
@@ -443,24 +434,16 @@ class T5Attention(nn.Module):
         else:
             position_bias_masked = position_bias
 
-
         scores += position_bias_masked
-        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-            scores
-        )  # (batch_size, n_heads, seq_length, key_length)
-        attn_weights = nn.functional.dropout(
-            attn_weights, p=self.dropout, training=self.training
-        )  # (batch_size, n_heads, seq_length, key_length)
+        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # (batch_size, n_heads, seq_length, key_length)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)  # (batch_size, n_heads, seq_length, key_length)
 
         # Mask heads if we want to
         if layer_head_mask is not None:
             attn_weights = attn_weights * layer_head_mask
 
 
-        if ...
-            attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
-        else:
-            attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, multivar_dim, seq_length, dim)
+        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
         attn_output = self.o(attn_output)
 
 
@@ -526,7 +509,6 @@ class T5LayerCrossAttention(nn.Module):
         query_length=None,
         output_attentions=False,
     ):
-
        normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.EncDecAttention(
             normed_hidden_states,
@@ -555,6 +537,8 @@ class T5Block(nn.Module):
 
         self.layer.append(T5LayerFF(config))
 
+        self.config = config
+
     def forward(
         self,
         hidden_states,
@@ -613,7 +597,10 @@ class T5Block(nn.Module):
         # the actual query length is unknown for cross attention
         # if using past key value states. Need to inject it here
         if present_key_value_state is not None:
-            ...
+            if self.config.is_mimo:
+                query_length = present_key_value_state[0].shape[3]
+            else:
+                query_length = present_key_value_state[0].shape[2]
         else:
             query_length = None
 
@@ -885,19 +872,14 @@ class T5Stack(T5PreTrainedModel):
             self.embed_tokens = self.embed_tokens.to(self.first_device)
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
+        output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if input_ids is not None and inputs_embeds is not None:
             err_msg_prefix = "decoder_" if self.is_decoder else ""
-            raise ValueError(
-                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
-            )
+            raise ValueError(f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
-            # input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
         else:
@@ -909,13 +891,16 @@ class T5Stack(T5PreTrainedModel):
                 raise ValueError("You have to initialize the model with valid token embeddings")
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if ...
+        if self.config.is_mimo:
             batch_size, multivar_seqs ,seq_length = input_shape
         else:
             batch_size, seq_length = input_shape
 
         # required mask seq length can be calculated via length of past
-        ...
+        if self.config.is_mimo:
+            mask_seq_length = past_key_values[0][0].shape[3] + seq_length if past_key_values is not None else seq_length
+        else:
+            mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
         if use_cache is True:
             if not self.is_decoder:
@@ -926,45 +911,34 @@ class T5Stack(T5PreTrainedModel):
             past_key_values = [None] * len(self.block)
 
         if attention_mask is None:
-            if ...
-                attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
-            else:
-                attention_mask = torch.ones(batch_size, multivar_seqs, mask_seq_length, device=inputs_embeds.device)
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
 
 
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        if ...
-            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+        if self.config.is_mimo:
+            extended_attention_mask = self.get_extended_attention_mask(attention_mask, (input_shape[0], input_shape[2]))
+            extended_attention_mask = extended_attention_mask.unsqueeze(1)
         else:
             extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
-            # permute from [batch_size, 1, multivar_seqs, seq_length] to [batch_size, multivar_seqs, 1, seq_length]
-            extended_attention_mask = extended_attention_mask.permute(0, 2, 1, 3)
-            # Now make it [batch_size, multivar_seqs, 1, 1, seq_length]
-            extended_attention_mask = extended_attention_mask.unsqueeze(3)
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.is_decoder and encoder_hidden_states is not None:
-            ...
-            if len(encoder_hidden_states.size()) == 3 :
-                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
-            else:
+            if self.config.is_mimo:
                 encoder_batch_size, multivar_dem, encoder_sequence_length, _ = encoder_hidden_states.size()
+            else:
+                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
 
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(
-                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
-                )
-            if len(input_shape) == 2:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)
+            if self.config.is_mimo:
                 encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+                encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(1)
             else:
                 encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-                multivar_dim = extended_attention_mask.shape[1]
-                encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(1)
-                encoder_extended_attention_mask = encoder_extended_attention_mask.permute(0, 3, 1, 2, 4)
 
         else:
             encoder_extended_attention_mask = None
@@ -973,9 +947,7 @@ class T5Stack(T5PreTrainedModel):
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
+                logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                 use_cache = False
 
         # Prepare head mask if needed
@@ -1453,6 +1425,8 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
         >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
         >>> # studies have shown that owning a dog is good for you.
         ```"""
+
+
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -1461,6 +1435,8 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
         if self.config.num_layers == self.config.num_decoder_layers:
             decoder_head_mask = head_mask
 
+
+
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
             # Convert encoder inputs in embeddings if needed
@@ -1500,6 +1476,15 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
             if decoder_attention_mask is not None:
                 decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
 
+        if hidden_states is not None and decoder_input_ids is not None:
+            if len(hidden_states.shape) == 4:
+                batch_size, multivar_seqs, seq_length , model_dim = hidden_states.shape
+                if len(decoder_input_ids.shape) == 2:
+                    decoder_input_ids = decoder_input_ids.unsqueeze(1).repeat(1, multivar_seqs, 1)
+
+
+
+
         # Decode
         decoder_outputs = self.decoder(
             input_ids=decoder_input_ids,
@@ -1518,6 +1503,7 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
 
         sequence_output = decoder_outputs[0]
 
+
         if use_conv:
             sequence_output = self.conv_block(sequence_output)
 
@@ -1548,8 +1534,11 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
         if not return_dict:
             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output
+
+
+
 
-        return Seq2SeqLMOutput(
+        seq2seqlmoutput = Seq2SeqLMOutput(
             loss=loss,
             logits=lm_logits,
             past_key_values=decoder_outputs.past_key_values,
@@ -1560,6 +1549,7 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
         )
+        return seq2seqlmoutput
 
     def prepare_inputs_for_generation(
         self,
@@ -1640,6 +1630,7 @@ class T5MIMOEncoderModel(T5PreTrainedModel):
 
     def __init__(self, config: T5MIMOConfig):
        super().__init__(config)
+
         self.shared = nn.Embedding(config.vocab_size, config.d_model)
 
         encoder_config = copy.deepcopy(config)
|