zR commited on
Commit
fe11ac1
1 Parent(s): e190b08

flash attn support

Browse files
Files changed (2) hide show
  1. config.json +1 -0
  2. modeling_chatglm.py +149 -5
config.json CHANGED
@@ -31,6 +31,7 @@
31
  "apply_residual_connection_post_layernorm": false,
32
  "attention_dropout": 0.0,
33
  "attention_softmax_in_fp32": true,
 
34
  "bias_dropout_fusion": true,
35
  "ffn_hidden_size": 13696,
36
  "fp32_residual_connection": false,
 
31
  "apply_residual_connection_post_layernorm": false,
32
  "attention_dropout": 0.0,
33
  "attention_softmax_in_fp32": true,
34
+ "attn_implementation": "sdpa",
35
  "bias_dropout_fusion": true,
36
  "ffn_hidden_size": 13696,
37
  "fp32_residual_connection": false,
modeling_chatglm.py CHANGED
@@ -21,16 +21,21 @@ from transformers.modeling_outputs import (
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
- from transformers.utils import logging
 
25
  from transformers.generation.logits_process import LogitsProcessor
26
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
27
 
28
  from .configuration_chatglm import ChatGLMConfig
29
  from .visual import EVA2CLIPModel
30
 
 
 
 
 
31
  # flags required to enable jit fusion kernels
32
 
33
- if sys.platform != 'darwin':
34
  torch._C._jit_set_profiling_mode(False)
35
  torch._C._jit_set_profiling_executor(False)
36
  torch._C._jit_override_can_fuse_on_cpu(True)
@@ -44,6 +49,7 @@ VISION_TOKEN_TYPE = 1
44
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
45
  _CONFIG_FOR_DOC = "ChatGLMConfig"
46
 
 
47
  def default_init(cls, *args, **kwargs):
48
  return cls(*args, **kwargs)
49
 
@@ -323,6 +329,130 @@ class CoreAttention(torch.nn.Module):
323
  return context_layer
324
 
325
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  class SelfAttention(torch.nn.Module):
327
  """Parallel self-attention layer abstract class.
328
 
@@ -687,12 +817,18 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
687
  config_class = ChatGLMConfig
688
  base_model_prefix = "transformer"
689
  _no_split_modules = ["GLMBlock"]
 
 
690
 
691
  def _init_weights(self, module: nn.Module):
692
  """Initialize the weights."""
693
  return
694
 
695
  def get_masks(self, input_embeds, past_key_values, padding_mask=None):
 
 
 
 
696
  batch_size, seq_length, embed_size = input_embeds.shape
697
  full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
698
  full_attention_mask.tril_()
@@ -839,6 +975,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
839
  # not allow for inputs_embeds, because we want to process image feature
840
  assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
841
  if not is_empty(images): # multi-modality
 
842
  image_size: int = self.config.vision_config['image_size']
843
  patch_size: int = self.config.vision_config['patch_size']
844
  num_patches = (image_size // patch_size // 2) ** 2
@@ -858,7 +995,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
858
  self.config.eoi_token_id)
859
  assert eoi_token_pos - boi_token_pos == 2
860
  new_input_embeds.append(torch.cat(
861
- (inputs_embeds[i, :boi_token_pos], images_features[i].to(inputs_embeds.device), inputs_embeds[i, eoi_token_pos + 1:])))
 
862
  new_position_ids.append(torch.cat(
863
  (position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches),
864
  position_ids[i, eoi_token_pos:])
@@ -981,10 +1119,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
981
  patch_size: int = self.config.vision_config['patch_size']
982
  num_patches = (image_size // patch_size // 2) ** 2
983
  new_attention_masks = []
 
 
 
 
 
984
  for i in range(len(input_ids)):
985
  input_id = input_ids[i].tolist()
986
- boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
987
- self.config.eoi_token_id)
 
988
  assert eoi_token_pos - boi_token_pos == 2
989
  new_attention_masks.append(torch.cat(
990
  (attention_mask[i, :boi_token_pos + 1], attention_mask.new_ones(num_patches),
 
21
  SequenceClassifierOutputWithPast,
22
  )
23
  from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.utils import logging, is_torch_npu_available, is_flash_attn_greater_or_equal_2_10, \
25
+ is_flash_attn_2_available
26
  from transformers.generation.logits_process import LogitsProcessor
27
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
28
 
29
  from .configuration_chatglm import ChatGLMConfig
30
  from .visual import EVA2CLIPModel
31
 
32
+ if is_flash_attn_2_available():
33
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
34
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
35
+
36
  # flags required to enable jit fusion kernels
37
 
38
+ if sys.platform != 'darwin' and not is_torch_npu_available():
39
  torch._C._jit_set_profiling_mode(False)
40
  torch._C._jit_set_profiling_executor(False)
41
  torch._C._jit_override_can_fuse_on_cpu(True)
 
49
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
50
  _CONFIG_FOR_DOC = "ChatGLMConfig"
51
 
52
+
53
  def default_init(cls, *args, **kwargs):
54
  return cls(*args, **kwargs)
55
 
 
329
  return context_layer
330
 
331
 
332
+ class SdpaAttention(CoreAttention):
333
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
334
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
335
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
336
+ is_causal=True,
337
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
338
+ else:
339
+ if attention_mask is not None:
340
+ attention_mask = ~attention_mask
341
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
342
+ attention_mask,
343
+ dropout_p=self.config.attention_dropout if self.training else 0.0)
344
+ context_layer = context_layer.transpose(1, 2).contiguous()
345
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
346
+ context_layer = context_layer.reshape(*new_context_layer_shape)
347
+ return context_layer
348
+
349
+
350
+ def _get_unpad_data(attention_mask):
351
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
352
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
353
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
354
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
355
+ return (
356
+ indices,
357
+ cu_seqlens,
358
+ max_seqlen_in_batch,
359
+ )
360
+
361
+
362
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2
363
+ class FlashAttention2(CoreAttention):
364
+ def __init__(self, *args, **kwargs):
365
+ super().__init__(*args, **kwargs)
366
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
367
+
368
+ def forward(self, query_states, key_states, value_states, attention_mask):
369
+ query_states = query_states.transpose(1, 2)
370
+ key_states = key_states.transpose(1, 2)
371
+ value_states = value_states.transpose(1, 2)
372
+ batch_size, query_length = query_states.shape[:2]
373
+ if not self._flash_attn_uses_top_left_mask:
374
+ causal = self.is_causal
375
+ else:
376
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
377
+ causal = self.is_causal and query_length != 1
378
+ dropout = self.config.attention_dropout if self.training else 0.0
379
+ # Contains at least one padding token in the sequence
380
+ if attention_mask is not None:
381
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
382
+ query_states, key_states, value_states, attention_mask, query_length
383
+ )
384
+
385
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
386
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
387
+
388
+ attn_output_unpad = flash_attn_varlen_func(
389
+ query_states,
390
+ key_states,
391
+ value_states,
392
+ cu_seqlens_q=cu_seqlens_q,
393
+ cu_seqlens_k=cu_seqlens_k,
394
+ max_seqlen_q=max_seqlen_in_batch_q,
395
+ max_seqlen_k=max_seqlen_in_batch_k,
396
+ dropout_p=dropout,
397
+ softmax_scale=None,
398
+ causal=causal,
399
+ )
400
+
401
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
402
+ else:
403
+ attn_output = flash_attn_func(
404
+ query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal
405
+ )
406
+ attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous()
407
+ return attn_output
408
+
409
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
410
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
411
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
412
+
413
+ key_layer = index_first_axis(
414
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
415
+ )
416
+ value_layer = index_first_axis(
417
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
418
+ )
419
+ if query_length == kv_seq_len:
420
+ query_layer = index_first_axis(
421
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
422
+ indices_k
423
+ )
424
+ cu_seqlens_q = cu_seqlens_k
425
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
426
+ indices_q = indices_k
427
+ elif query_length == 1:
428
+ max_seqlen_in_batch_q = 1
429
+ cu_seqlens_q = torch.arange(
430
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
431
+ ) # There is a memcpy here, that is very bad.
432
+ indices_q = cu_seqlens_q[:-1]
433
+ query_layer = query_layer.squeeze(1)
434
+ else:
435
+ # The -q_len: slice assumes left padding.
436
+ attention_mask = attention_mask[:, -query_length:]
437
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
438
+
439
+ return (
440
+ query_layer,
441
+ key_layer,
442
+ value_layer,
443
+ indices_q,
444
+ (cu_seqlens_q, cu_seqlens_k),
445
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
446
+ )
447
+
448
+
449
+ CORE_ATTENTION_CLASSES = {
450
+ "eager": CoreAttention,
451
+ "sdpa": SdpaAttention,
452
+ "flash_attention_2": FlashAttention2
453
+ }
454
+
455
+
456
  class SelfAttention(torch.nn.Module):
457
  """Parallel self-attention layer abstract class.
458
 
 
817
  config_class = ChatGLMConfig
818
  base_model_prefix = "transformer"
819
  _no_split_modules = ["GLMBlock"]
820
+ _supports_flash_attn_2 = True
821
+ _supports_sdpa = True
822
 
823
  def _init_weights(self, module: nn.Module):
824
  """Initialize the weights."""
825
  return
826
 
827
  def get_masks(self, input_embeds, past_key_values, padding_mask=None):
828
+ if self.config._attn_implementation == "flash_attention_2":
829
+ if padding_mask is not None and not padding_mask.all():
830
+ return padding_mask
831
+ return None
832
  batch_size, seq_length, embed_size = input_embeds.shape
833
  full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
834
  full_attention_mask.tril_()
 
975
  # not allow for inputs_embeds, because we want to process image feature
976
  assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
977
  if not is_empty(images): # multi-modality
978
+
979
  image_size: int = self.config.vision_config['image_size']
980
  patch_size: int = self.config.vision_config['patch_size']
981
  num_patches = (image_size // patch_size // 2) ** 2
 
995
  self.config.eoi_token_id)
996
  assert eoi_token_pos - boi_token_pos == 2
997
  new_input_embeds.append(torch.cat(
998
+ (inputs_embeds[i, :boi_token_pos], images_features[i].to(inputs_embeds.device),
999
+ inputs_embeds[i, eoi_token_pos + 1:])))
1000
  new_position_ids.append(torch.cat(
1001
  (position_ids[i, :boi_token_pos + 1], position_ids[i, boi_token_pos + 1].repeat(num_patches),
1002
  position_ids[i, eoi_token_pos:])
 
1119
  patch_size: int = self.config.vision_config['patch_size']
1120
  num_patches = (image_size // patch_size // 2) ** 2
1121
  new_attention_masks = []
1122
+
1123
+ # if not image, use this default id
1124
+ eoi_token_pos = 6
1125
+ boi_token_pos = 4
1126
+
1127
  for i in range(len(input_ids)):
1128
  input_id = input_ids[i].tolist()
1129
+ if not is_empty(images):
1130
+ boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
1131
+ self.config.eoi_token_id)
1132
  assert eoi_token_pos - boi_token_pos == 2
1133
  new_attention_masks.append(torch.cat(
1134
  (attention_mask[i, :boi_token_pos + 1], attention_mask.new_ones(num_patches),