RuntimeError: FlashAttention only support fp16 and bf16 data type — when training with QLoRA and Flash Attention 2

#33
by liougehooa - opened

When I train with QLoRA + Flash Attention 2, I get this error. When I train with LoRA + Flash Attention 2, I don't.
Here are the relevant code snippets for the QLoRA run:

base_model_id = "microsoft/Phi-3.5-vision-instruct"

# One dtype for both the 4-bit compute dtype and the non-quantized weights.
# flash_attn_func only accepts fp16/bf16 tensors; if `torch_dtype` differs from
# the compute dtype (e.g. float32), the vision tower can hand float32 q/k/v to
# FlashAttention and raise "FlashAttention only support fp16 and bf16 data type".
compute_dtype = torch.bfloat16

# Initialize model
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=compute_dtype,
    trust_remote_code=True,
    quantization_config=bnb_config,
    # Replaces the deprecated `use_flash_attention_2=True` flag.
    attn_implementation="flash_attention_2",
    # Load the quantized weights straight onto the target device instead of
    # moving a bitsandbytes 4-bit model afterwards with `.to(device)`.
    device_map={"": device},
)

processor = transformers.AutoProcessor.from_pretrained(base_model_id, trust_remote_code=True)


# Training configuration. `eval_strategy` replaces the deprecated
# `evaluation_strategy` argument (renamed in transformers 4.41; the run above
# uses 4.45.2). Checkpointing and evaluation both run every 10 steps so that
# `load_best_model_at_end` can pick the best checkpoint by `metric_for_best_model`.
training_args = transformers.TrainingArguments(
    num_train_epochs=40,                          # Number of training epochs
    per_device_train_batch_size=batch_size,       # Batch size for training
    per_device_eval_batch_size=batch_size,        # Batch size for evaluation
    gradient_accumulation_steps=2,                # Number of steps to accumulate gradients before updating
    gradient_checkpointing=True,                  # Enable gradient checkpointing to save memory
    do_eval=True,                                 # Perform evaluation during training
    save_total_limit=2,                           # Limit the total number of saved checkpoints
    eval_strategy="steps",                        # Evaluate every `eval_steps` steps
    save_strategy="steps",                        # Save checkpoints every `save_steps` steps
    save_steps=10,                                # Number of steps between each checkpoint save
    eval_steps=10,                                # Number of steps between each evaluation
    max_grad_norm=1,                              # Maximum gradient norm for clipping
    warmup_ratio=0.1,                             # Warmup ratio for learning rate schedule
    weight_decay=0.001,                           # Regularization technique to prevent overfitting
    # fp16=True,                                  # Enable mixed precision training with fp16 (enable it if Ampere architecture is unavailable)
    bf16=True,                                    # Enable mixed precision training with bf16
    logging_steps=10,                             # Number of steps between each log
    output_dir="outputs",                         # Directory to save the model outputs and checkpoints
    optim="adamw_torch",                          # Optimizer to use (AdamW with PyTorch)
    learning_rate=5e-5,                           # Learning rate for the optimizer
    lr_scheduler_type="linear",                   # Learning rate scheduler type: linear decay after warmup
    load_best_model_at_end=True,                  # Load the best model found during training at the end
    metric_for_best_model="rouge",                # Metric used to determine the best model (key returned by compute_metrics)
    greater_is_better=True,                       # Indicates if a higher metric score is better
    push_to_hub=False,                            # Whether to push the model to Hugging Face Hub
    run_name="phi-3-5-vision-finetuning",         # Name of the run for experiment tracking
    report_to="wandb",                            # For experiment tracking (login to Weights & Biases needed)
)

class CustomTrainer(transformers.Trainer):
    """`transformers.Trainer` subclass that overrides how dataloaders are built.

    Method bodies are elided (``...``) in this report. The traceback further
    down shows the full class also overrides ``compute_loss``.
    """

    def get_train_dataloader(self):
        """Build the training DataLoader (implementation elided in this report)."""
        ...

    def get_eval_dataloader(self, eval_dataset=None):
        """Build the evaluation DataLoader (implementation elided in this report)."""
        ...

def compute_metrics(eval_pred):
    """Score teacher-forced greedy predictions against the labels with ROUGE-1.

    Args:
        eval_pred: ``(logits, labels)`` pair passed in by the Trainer. ``logits``
            may be a tuple of model outputs, in which case its first element is
            used as the logits array.

    Returns:
        ``{"rouge": score}``, where ``score`` is the ``"rouge1"`` entry returned by
        ``rouge.compute``; the key matches ``metric_for_best_model="rouge"``.
    """
    logits, labels = eval_pred
    # Models that return extra outputs make the Trainer pass a tuple here;
    # calling `.argmax` on a tuple would fail.
    if isinstance(logits, tuple):
        logits = logits[0]

    # Causal LM: the logit at position t predicts the token at position t + 1,
    # so drop the last prediction and the first label to align them.
    predicted = np.asarray(logits).argmax(-1)[:, :-1]
    labels = np.asarray(labels)[:, 1:]

    # Positions labelled -100 are ignored by the loss (prompt / padding). Blank
    # them in BOTH arrays so the prompt is not decoded as part of the prediction.
    pad_id = processor.tokenizer.pad_token_id
    ignored = labels == -100
    labels = np.where(ignored, pad_id, labels)
    predicted = np.where(ignored, pad_id, predicted)

    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    rouge_scores = rouge.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"rouge": rouge_scores["rouge1"]}

# Ensure the model is in training mode
# NOTE(review): `peft_model`, `data_collator`, `train_dataset` and `val_dataset`
# are created in code not shown in this report. If `peft_model` was built after
# peft's `prepare_model_for_kbit_training`, that helper may have upcast the
# non-quantized fp16/bf16 weights to float32 — confirm, because float32 q/k/v
# reaching flash_attn_func is exactly what the RuntimeError below reports.
peft_model.train()

# Wire the PEFT model, training arguments, data and metric function together.
trainer = CustomTrainer(
    model=peft_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

# The KV cache is only useful for generation; disable it for training
# (training_args also enables gradient checkpointing).
peft_model.config.use_cache = False

trainer.train()

The error:

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
     49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
     52     q,
     53     k,
     54     v,
     55     None,
     56     alibi_slopes,
     57     dropout_p,
     58     softmax_scale,
     59     causal,
     60     window_size[0],
     61     window_size[1],
     62     return_softmax,
     63     None,
     64 )
     65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
RuntimeError: FlashAttention only support fp16 and bf16 data type

There is a similar problem reported here:
https://huggingface.co/microsoft/Phi-3-small-8k-instruct/discussions/11

As a test, I added the following dtype conversion to CLIPAttentionFA2 (in modeling_phi3_v.py), just before the call to flash_attn_func:

query_states = self.q_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
key_states = self.k_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
value_states = self.v_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)

## flash_attn_func only accepts fp16/bf16 tensors. Under QLoRA these projections
## can come back as float32, so cast down only in that case rather than always
## forcing a hand-picked dtype. Prefer the active autocast dtype (bf16=True in
## TrainingArguments enables autocast); otherwise fall back to bfloat16, which
## the Ampere-or-newer GPUs targeted by FlashAttention-2 support.
if query_states.dtype == torch.float32:
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    else:
        target_dtype = torch.bfloat16
    query_states = query_states.to(target_dtype)
    key_states = key_states.to(target_dtype)
    value_states = value_states.to(target_dtype)

With this change, QLoRA training runs without the error.

Library versions:
flash-attn (FlashAttention-2): 2.5.8
transformers: 4.45.2
peft: 0.11.1
bitsandbytes: 0.44.1

The full traceback:

RuntimeError                              Traceback (most recent call last)

----> 1 trainer.train()
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/trainer.py:2123, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2121         hf_hub_utils.enable_progress_bars()
   2122 else:
-> 2123     return inner_training_loop(
   2124         args=args,
   2125         resume_from_checkpoint=resume_from_checkpoint,
   2126         trial=trial,
   2127         ignore_keys_for_eval=ignore_keys_for_eval,
   2128     )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/trainer.py:2481, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2475 context = (
   2476     functools.partial(self.accelerator.no_sync, model=model)
   2477     if i == len(batch_samples) - 1
   2478     else contextlib.nullcontext
   2479 )
   2480 with context():
-> 2481     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   2483 if (
   2484     args.logging_nan_inf_filter
   2485     and not is_torch_xla_available()
   2486     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2487 ):
   2488     # if loss is nan or inf simply add the average of previous logged losses
   2489     tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/trainer.py:3579, in Trainer.training_step(self, model, inputs, num_items_in_batch)
   3576     return loss_mb.reduce_mean().detach().to(self.args.device)
   3578 with self.compute_loss_context_manager():
-> 3579     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
   3581 del inputs
   3582 if (
   3583     self.args.torch_empty_cache_steps is not None
   3584     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3585 ):
Cell In[38], line 24, in CustomTrainer.compute_loss(self, model, inputs, num_items_in_batch, return_outputs)
     23 def compute_loss(self, model, inputs, num_items_in_batch=0, return_outputs=False):
---> 24     outputs = model(**inputs)
     25     loss = outputs.loss if isinstance(outputs, dict) else outputs[0]
     26     return (loss, outputs) if return_outputs else loss
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/utils/operations.py:823, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    822 def forward(*args, **kwargs):
--> 823     return model_forward(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/utils/operations.py:811, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    810 def __call__(self, *args, **kwargs):
--> 811     return convert_to_fp32(self.model_forward(*args, **kwargs))
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/peft/peft_model.py:1577, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1575     with self._enable_peft_forward_hooks(**kwargs):
   1576         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1577         return self.base_model(
   1578             input_ids=input_ids,
   1579             attention_mask=attention_mask,
   1580             inputs_embeds=inputs_embeds,
   1581             labels=labels,
   1582             output_attentions=output_attentions,
   1583             output_hidden_states=output_hidden_states,
   1584             return_dict=return_dict,
   1585             **kwargs,
   1586         )
   1588 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1589 if attention_mask is not None:
   1590     # concat prompt attention mask
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/peft/tuners/tuners_utils.py:188, in BaseTuner.forward(self, *args, **kwargs)
    187 def forward(self, *args: Any, **kwargs: Any):
--> 188     return self.model.forward(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3.5-vision-instruct/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py:1603, in Phi3VForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, image_sizes, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1600 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1602 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1603 outputs = self.model(
   1604     input_ids=input_ids,
   1605     attention_mask=attention_mask,
   1606     position_ids=position_ids,
   1607     past_key_values=past_key_values,
   1608     inputs_embeds=inputs_embeds,
   1609     pixel_values=pixel_values,
   1610     image_sizes=image_sizes,
   1611     use_cache=use_cache,
   1612     output_attentions=output_attentions,
   1613     output_hidden_states=output_hidden_states,
   1614     return_dict=return_dict,
   1615 )
   1617 hidden_states = outputs[0]
   1618 logits = self.lm_head(hidden_states)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3.5-vision-instruct/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py:1431, in Phi3VModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, image_sizes, use_cache, output_attentions, output_hidden_states, return_dict)
   1429 if pixel_values is not None and image_sizes is not None:
   1430     assert self.vision_embed_tokens is not None, "Vision embedding layer is not defined"
-> 1431     inputs_embeds = self.vision_embed_tokens(input_ids, pixel_values=pixel_values, image_sizes=image_sizes)
   1432 else:
   1433     inputs_embeds = self.embed_tokens(input_ids)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3.5-vision-instruct/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py:237, in Phi3ImageEmbedding.forward(self, input_ids, pixel_values, image_sizes)
    235 num_images, num_crops, c, h, w = pixel_values.shape
    236 assert c == 3 and h == w == 336
--> 237 img_features = self.get_img_features(pixel_values.flatten(0, 1)).reshape(
    238     num_images, num_crops, -1, self.image_dim_out
    239 )
    240 image_features_proj = self.hd_feature_transform(img_features, image_sizes)
    241 hidden_states = hidden_states.index_put(
    242     positions, image_features_proj, accumulate=False
    243 )
File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3.5-vision-instruct/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py:212, in Phi3ImageEmbedding.get_img_features(self, img_embeds)
    209 LAYER_IDX = self.layer_idx
    210 TYPE_FEATURE = self.type_feature
--> 212 img_processor_output = self.img_processor(img_embeds, output_hidden_states=True)
    213 img_feature = img_processor_output.hidden_states[LAYER_IDX]
    215 if TYPE_FEATURE == "patch":
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py:1171, in CLIPVisionModel.forward(self, pixel_values, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict)
   1147 r"""
   1148 Returns:
   1149 
   (...)
   1167 >>> pooled_output = outputs.pooler_output  # pooled CLS states
   1168 ```"""
   1169 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 1171 return self.vision_model(
   1172     pixel_values=pixel_values,
   1173     output_attentions=output_attentions,
   1174     output_hidden_states=output_hidden_states,
   1175     return_dict=return_dict,
   1176     interpolate_pos_encoding=interpolate_pos_encoding,
   1177 )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py:1097, in CLIPVisionTransformer.forward(self, pixel_values, output_attentions, output_hidden_states, return_dict, interpolate_pos_encoding)
   1094 hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
   1095 hidden_states = self.pre_layrnorm(hidden_states)
-> 1097 encoder_outputs = self.encoder(
   1098     inputs_embeds=hidden_states,
   1099     output_attentions=output_attentions,
   1100     output_hidden_states=output_hidden_states,
   1101     return_dict=return_dict,
   1102 )
   1104 last_hidden_state = encoder_outputs[0]
   1105 pooled_output = last_hidden_state[:, 0, :]
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py:869, in CLIPEncoder.forward(self, inputs_embeds, attention_mask, causal_attention_mask, output_attentions, output_hidden_states, return_dict)
    867     encoder_states = encoder_states + (hidden_states,)
    868 if self.gradient_checkpointing and self.training:
--> 869     layer_outputs = self._gradient_checkpointing_func(
    870         encoder_layer.__call__,
    871         hidden_states,
    872         attention_mask,
    873         causal_attention_mask,
    874         output_attentions,
    875     )
    876 else:
    877     layer_outputs = encoder_layer(
    878         hidden_states,
    879         attention_mask,
    880         causal_attention_mask,
    881         output_attentions=output_attentions,
    882     )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/_compile.py:24, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     20 @functools.wraps(fn)
     21 def inner(*args, **kwargs):
     22     import torch._dynamo
---> 24     return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:451, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    449 prior = set_eval_frame(callback)
    450 try:
--> 451     return fn(*args, **kwargs)
    452 finally:
    453     set_eval_frame(prior)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/_dynamo/external_utils.py:36, in wrap_inline.<locals>.inner(*args, **kwargs)
     34 @functools.wraps(fn)
     35 def inner(*args, **kwargs):
---> 36     return fn(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/utils/checkpoint.py:487, in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
    482     if context_fn is not noop_context_fn or debug is not False:
    483         raise ValueError(
    484             "Passing `context_fn` or `debug` is only supported when "
    485             "use_reentrant=False."
    486         )
--> 487     return CheckpointFunction.apply(function, preserve, *args)
    488 else:
    489     gen = _checkpoint_without_reentrant_generator(
    490         function, preserve, context_fn, determinism_check, debug, *args, **kwargs
    491     )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
    595 if not torch._C._are_functorch_transforms_active():
    596     # See NOTE: [functorch vjp and autograd interaction]
    597     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598     return super().apply(*args, **kwargs)  # type: ignore[misc]
    600 if not is_setup_ctx_defined:
    601     raise RuntimeError(
    602         "In order to use an autograd.Function with functorch transforms "
    603         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    604         "staticmethod. For more details, please see "
    605         " https://pytorch.org/docs/master/notes/extending.func.html"
    606     )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/utils/checkpoint.py:262, in CheckpointFunction.forward(ctx, run_function, preserve_rng_state, *args)
    259 ctx.save_for_backward(*tensor_inputs)
    261 with torch.no_grad():
--> 262     outputs = run_function(*args)
    263 return outputs
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py:608, in CLIPEncoderLayer.forward(self, hidden_states, attention_mask, causal_attention_mask, output_attentions)
    605 residual = hidden_states
    607 hidden_states = self.layer_norm1(hidden_states)
--> 608 hidden_states, attn_weights = self.self_attn(
    609     hidden_states=hidden_states,
    610     attention_mask=attention_mask,
    611     causal_attention_mask=causal_attention_mask,
    612     output_attentions=output_attentions,
    613 )
    614 hidden_states = residual + hidden_states
    616 residual = hidden_states
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)
File ~/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3.5-vision-instruct/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py:105, in CLIPAttentionFA2.forward(self, hidden_states, attention_mask, causal_attention_mask, output_attentions)
    102 key_states = self.k_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
    103 value_states = self.v_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim)
--> 105 attn_output = flash_attn_func(
    106     query_states,
    107     key_states,
    108     value_states,
    109     dropout_p=self.dropout if self.training else 0.0,
    110     softmax_scale=self.scale,
    111     causal=False,
    112 ).reshape(bsz, tgt_len, embed_dim)
    114 attn_output = self.out_proj(attn_output)
    115 return attn_output, None
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py:831, in flash_attn_func(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs)
    771 def flash_attn_func(
    772     q,
    773     k,
   (...)
    781     return_attn_probs=False,
    782 ):
    783     """dropout_p should be set to 0.0 during evaluation
    784     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    785     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
   (...)
    829             pattern (negative means that location was dropped, nonnegative means it was kept).
    830     """
--> 831     return FlashAttnFunc.apply(
    832         q,
    833         k,
    834         v,
    835         dropout_p,
    836         softmax_scale,
    837         causal,
    838         window_size,
    839         alibi_slopes,
    840         deterministic,
    841         return_attn_probs,
    842     )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
    595 if not torch._C._are_functorch_transforms_active():
    596     # See NOTE: [functorch vjp and autograd interaction]
    597     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598     return super().apply(*args, **kwargs)  # type: ignore[misc]
    600 if not is_setup_ctx_defined:
    601     raise RuntimeError(
    602         "In order to use an autograd.Function with functorch transforms "
    603         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    604         "staticmethod. For more details, please see "
    605         " https://pytorch.org/docs/master/notes/extending.func.html"
    606     )
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py:511, in FlashAttnFunc.forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax)
    509 if softmax_scale is None:
    510     softmax_scale = q.shape[-1] ** (-0.5)
--> 511 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
    512     q,
    513     k,
    514     v,
    515     dropout_p,
    516     softmax_scale,
    517     causal=causal,
    518     window_size=window_size,
    519     alibi_slopes=alibi_slopes,
    520     return_softmax=return_softmax and dropout_p > 0,
    521 )
    522 ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
    523 ctx.dropout_p = dropout_p
File /anaconda/envs/azureml_py38_PT_TF/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
     49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
     52     q,
     53     k,
     54     v,
     55     None,
     56     alibi_slopes,
     57     dropout_p,
     58     softmax_scale,
     59     causal,
     60     window_size[0],
     61     window_size[1],
     62     return_softmax,
     63     None,
     64 )
     65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
RuntimeError: FlashAttention only support fp16 and bf16 data type

Sign up or log in to comment