x54-729
commited on
Commit
•
4881d44
1
Parent(s):
0edf336
support flash attn 2
Browse files- configuration_internlm.py +4 -0
- modeling_internlm.py +186 -18
configuration_internlm.py
CHANGED
@@ -91,6 +91,7 @@ class InternLMConfig(PretrainedConfig):
|
|
91 |
tie_word_embeddings=False,
|
92 |
bias=True,
|
93 |
rotary={"base": 10000, "type": "dynamic"}, # pylint: disable=W0102
|
|
|
94 |
**kwargs,
|
95 |
):
|
96 |
self.vocab_size = vocab_size
|
@@ -105,6 +106,9 @@ class InternLMConfig(PretrainedConfig):
|
|
105 |
self.use_cache = use_cache
|
106 |
self.bias = bias
|
107 |
self.rotary = rotary
|
|
|
|
|
|
|
108 |
super().__init__(
|
109 |
pad_token_id=pad_token_id,
|
110 |
bos_token_id=bos_token_id,
|
|
|
91 |
tie_word_embeddings=False,
|
92 |
bias=True,
|
93 |
rotary={"base": 10000, "type": "dynamic"}, # pylint: disable=W0102
|
94 |
+
attn_implementation="eager",
|
95 |
**kwargs,
|
96 |
):
|
97 |
self.vocab_size = vocab_size
|
|
|
106 |
self.use_cache = use_cache
|
107 |
self.bias = bias
|
108 |
self.rotary = rotary
|
109 |
+
self.attn_implementation = attn_implementation
|
110 |
+
if self.attn_implementation is None:
|
111 |
+
self.attn_implementation = "eager"
|
112 |
super().__init__(
|
113 |
pad_token_id=pad_token_id,
|
114 |
bos_token_id=bos_token_id,
|
modeling_internlm.py
CHANGED
@@ -1,10 +1,6 @@
|
|
1 |
-
#
|
2 |
-
# Copyright (c) InternLM. All rights reserved.
|
3 |
#
|
4 |
-
# This code is based on
|
5 |
-
# and OPT implementations in this library. It has been modified from its
|
6 |
-
# original forms to accommodate minor architectural differences compared
|
7 |
-
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
#
|
9 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
# you may not use this file except in compliance with the License.
|
@@ -52,6 +48,17 @@ logger = logging.get_logger(__name__)
|
|
52 |
|
53 |
_CONFIG_FOR_DOC = "InternLMConfig"
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
57 |
def _make_causal_mask(
|
@@ -85,7 +92,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|
85 |
|
86 |
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
87 |
|
88 |
-
|
89 |
class InternLMRMSNorm(nn.Module):
|
90 |
"""RMSNorm implemention."""
|
91 |
|
@@ -228,8 +234,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|
228 |
k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
|
229 |
k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
|
230 |
else:
|
231 |
-
cos = cos[position_ids].unsqueeze(1)
|
232 |
-
sin = sin[position_ids].unsqueeze(1)
|
233 |
q_embed = (q * cos) + (rotate_half(q) * sin)
|
234 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
235 |
return q_embed, k_embed
|
@@ -273,6 +279,7 @@ class InternLMAttention(nn.Module):
|
|
273 |
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
274 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
|
275 |
self.rotary_emb = self._init_rope()
|
|
|
276 |
|
277 |
def _init_rope(self):
|
278 |
if self.config.rotary["type"] == "origin":
|
@@ -356,13 +363,167 @@ class InternLMAttention(nn.Module):
|
|
356 |
attn_weights = None
|
357 |
|
358 |
return attn_output, attn_weights, past_key_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
|
361 |
class InternLMDecoderLayer(nn.Module):
|
362 |
def __init__(self, config: InternLMConfig):
|
363 |
super().__init__()
|
364 |
self.hidden_size = config.hidden_size
|
365 |
-
|
|
|
|
|
366 |
self.mlp = InternLMMLP(
|
367 |
hidden_size=self.hidden_size,
|
368 |
intermediate_size=config.intermediate_size,
|
@@ -539,8 +700,10 @@ class InternLMModel(InternLMPreTrainedModel):
|
|
539 |
super().__init__(config)
|
540 |
self.padding_idx = config.pad_token_id
|
541 |
self.vocab_size = config.vocab_size
|
|
|
542 |
|
543 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
|
|
544 |
self.layers = nn.ModuleList([InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
545 |
self.norm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
546 |
|
@@ -627,14 +790,16 @@ class InternLMModel(InternLMPreTrainedModel):
|
|
627 |
|
628 |
if inputs_embeds is None:
|
629 |
inputs_embeds = self.embed_tokens(input_ids)
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
|
|
|
|
|
|
|
|
|
|
634 |
)
|
635 |
-
attention_mask = self._prepare_decoder_attention_mask(
|
636 |
-
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
637 |
-
)
|
638 |
|
639 |
hidden_states = inputs_embeds
|
640 |
|
@@ -759,6 +924,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
759 |
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
760 |
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
761 |
Returns:
|
|
|
762 |
Example:
|
763 |
```python
|
764 |
>>> from transformers import AutoTokenizer, InternLMForCausalLM
|
@@ -770,7 +936,9 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
770 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
771 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
772 |
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
773 |
-
```
|
|
|
|
|
774 |
|
775 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
776 |
output_hidden_states = (
|
|
|
1 |
+
# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
|
|
|
2 |
#
|
3 |
+
# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
|
|
|
|
|
|
|
4 |
#
|
5 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
# you may not use this file except in compliance with the License.
|
|
|
48 |
|
49 |
_CONFIG_FOR_DOC = "InternLMConfig"
|
50 |
|
51 |
+
def _get_unpad_data(attention_mask):
|
52 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
53 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
54 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
55 |
+
cu_seqlens = nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
56 |
+
return (
|
57 |
+
indices,
|
58 |
+
cu_seqlens,
|
59 |
+
max_seqlen_in_batch,
|
60 |
+
)
|
61 |
+
|
62 |
|
63 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
64 |
def _make_causal_mask(
|
|
|
92 |
|
93 |
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
94 |
|
|
|
95 |
class InternLMRMSNorm(nn.Module):
|
96 |
"""RMSNorm implemention."""
|
97 |
|
|
|
234 |
k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
|
235 |
k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
|
236 |
else:
|
237 |
+
cos = cos[position_ids].unsqueeze(1)
|
238 |
+
sin = sin[position_ids].unsqueeze(1)
|
239 |
q_embed = (q * cos) + (rotate_half(q) * sin)
|
240 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
241 |
return q_embed, k_embed
|
|
|
279 |
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
280 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
|
281 |
self.rotary_emb = self._init_rope()
|
282 |
+
self.is_causal = True
|
283 |
|
284 |
def _init_rope(self):
|
285 |
if self.config.rotary["type"] == "origin":
|
|
|
363 |
attn_weights = None
|
364 |
|
365 |
return attn_output, attn_weights, past_key_value
|
366 |
+
|
367 |
+
class InternLMFlashAttention2(InternLMAttention):
|
368 |
+
"""
|
369 |
+
InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
|
370 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
371 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
372 |
+
"""
|
373 |
|
374 |
+
def forward(
|
375 |
+
self,
|
376 |
+
hidden_states: torch.Tensor,
|
377 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
378 |
+
position_ids: Optional[torch.LongTensor] = None,
|
379 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
380 |
+
output_attentions: bool = False,
|
381 |
+
use_cache: bool = False,
|
382 |
+
**kwargs,
|
383 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
384 |
+
# InternLM2FlashAttention2 attention does not support output_attentions
|
385 |
+
bsz, q_len, _ = hidden_states.size()
|
386 |
+
|
387 |
+
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
388 |
+
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
389 |
+
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
390 |
+
|
391 |
+
if past_key_value is not None:
|
392 |
+
# reuse k, v, self_attention
|
393 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
394 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
395 |
+
|
396 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
397 |
+
|
398 |
+
kv_seq_len = key_states.shape[-2]
|
399 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
400 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
401 |
+
|
402 |
+
query_states = query_states.transpose(1, 2)
|
403 |
+
key_states = key_states.transpose(1, 2)
|
404 |
+
value_states = value_states.transpose(1, 2)
|
405 |
+
|
406 |
+
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
407 |
+
|
408 |
+
attn_output = self._flash_attention_forward(
|
409 |
+
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
410 |
+
)
|
411 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
412 |
+
attn_output = self.o_proj(attn_output)
|
413 |
+
|
414 |
+
if not output_attentions:
|
415 |
+
attn_weights = None
|
416 |
+
|
417 |
+
return attn_output, attn_weights, past_key_value
|
418 |
+
|
419 |
+
def _flash_attention_forward(
|
420 |
+
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
421 |
+
):
|
422 |
+
"""
|
423 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
424 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
425 |
+
|
426 |
+
Args:
|
427 |
+
query_states (`torch.Tensor`):
|
428 |
+
Input query states to be passed to Flash Attention API
|
429 |
+
key_states (`torch.Tensor`):
|
430 |
+
Input key states to be passed to Flash Attention API
|
431 |
+
value_states (`torch.Tensor`):
|
432 |
+
Input value states to be passed to Flash Attention API
|
433 |
+
attention_mask (`torch.Tensor`):
|
434 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
435 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
436 |
+
dropout (`int`, *optional*):
|
437 |
+
Attention dropout
|
438 |
+
softmax_scale (`float`, *optional*):
|
439 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
440 |
+
"""
|
441 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
442 |
+
from flash_attn.bert_padding import pad_input
|
443 |
+
# Contains at least one padding token in the sequence
|
444 |
+
causal = self.is_causal and query_length != 1
|
445 |
+
if attention_mask is not None:
|
446 |
+
batch_size = query_states.shape[0]
|
447 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
448 |
+
query_states, key_states, value_states, attention_mask, query_length
|
449 |
+
)
|
450 |
+
|
451 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
452 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
453 |
+
|
454 |
+
attn_output_unpad = flash_attn_varlen_func(
|
455 |
+
query_states,
|
456 |
+
key_states,
|
457 |
+
value_states,
|
458 |
+
cu_seqlens_q=cu_seqlens_q,
|
459 |
+
cu_seqlens_k=cu_seqlens_k,
|
460 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
461 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
462 |
+
dropout_p=dropout,
|
463 |
+
softmax_scale=softmax_scale,
|
464 |
+
causal=causal,
|
465 |
+
)
|
466 |
+
|
467 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
468 |
+
else:
|
469 |
+
attn_output = flash_attn_func(
|
470 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
471 |
+
)
|
472 |
+
|
473 |
+
return attn_output
|
474 |
+
|
475 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
476 |
+
from flash_attn.bert_padding import index_first_axis, unpad_input
|
477 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
478 |
+
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
479 |
+
|
480 |
+
key_layer = index_first_axis(
|
481 |
+
key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
482 |
+
)
|
483 |
+
value_layer = index_first_axis(
|
484 |
+
value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
485 |
+
)
|
486 |
+
|
487 |
+
if query_length == kv_seq_len:
|
488 |
+
query_layer = index_first_axis(
|
489 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
490 |
+
)
|
491 |
+
cu_seqlens_q = cu_seqlens_k
|
492 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
493 |
+
indices_q = indices_k
|
494 |
+
elif query_length == 1:
|
495 |
+
max_seqlen_in_batch_q = 1
|
496 |
+
cu_seqlens_q = torch.arange(
|
497 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
498 |
+
) # There is a memcpy here, that is very bad.
|
499 |
+
indices_q = cu_seqlens_q[:-1]
|
500 |
+
query_layer = query_layer.squeeze(1)
|
501 |
+
else:
|
502 |
+
# The -q_len: slice assumes left padding.
|
503 |
+
attention_mask = attention_mask[:, -query_length:]
|
504 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
505 |
+
|
506 |
+
return (
|
507 |
+
query_layer,
|
508 |
+
key_layer,
|
509 |
+
value_layer,
|
510 |
+
indices_q.to(torch.int64),
|
511 |
+
(cu_seqlens_q, cu_seqlens_k),
|
512 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
513 |
+
)
|
514 |
+
|
515 |
+
INTERNLM_ATTENTION_CLASSES = {
|
516 |
+
"eager": InternLMAttention,
|
517 |
+
"flash_attention_2": InternLMFlashAttention2,
|
518 |
+
}
|
519 |
|
520 |
class InternLMDecoderLayer(nn.Module):
|
521 |
def __init__(self, config: InternLMConfig):
|
522 |
super().__init__()
|
523 |
self.hidden_size = config.hidden_size
|
524 |
+
|
525 |
+
self.self_attn = INTERNLM_ATTENTION_CLASSES[config.attn_implementation](config=config)
|
526 |
+
|
527 |
self.mlp = InternLMMLP(
|
528 |
hidden_size=self.hidden_size,
|
529 |
intermediate_size=config.intermediate_size,
|
|
|
700 |
super().__init__(config)
|
701 |
self.padding_idx = config.pad_token_id
|
702 |
self.vocab_size = config.vocab_size
|
703 |
+
self.config = config
|
704 |
|
705 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
706 |
+
|
707 |
self.layers = nn.ModuleList([InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
708 |
self.norm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
709 |
|
|
|
790 |
|
791 |
if inputs_embeds is None:
|
792 |
inputs_embeds = self.embed_tokens(input_ids)
|
793 |
+
if self.config.attn_implementation == "flash_attention_2":
|
794 |
+
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
795 |
+
else:
|
796 |
+
if attention_mask is None:
|
797 |
+
attention_mask = torch.ones(
|
798 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
799 |
+
)
|
800 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
801 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
802 |
)
|
|
|
|
|
|
|
803 |
|
804 |
hidden_states = inputs_embeds
|
805 |
|
|
|
924 |
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
925 |
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
926 |
Returns:
|
927 |
+
|
928 |
Example:
|
929 |
```python
|
930 |
>>> from transformers import AutoTokenizer, InternLMForCausalLM
|
|
|
936 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
937 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
938 |
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
939 |
+
```
|
940 |
+
|
941 |
+
"""
|
942 |
|
943 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
944 |
output_hidden_states = (
|