RuntimeError: FlashAttention only support fp16 and bf16 data type
#15
Opened by Satandon1999
I am not able to resolve this error while trying to fine-tune this model.
I have loaded the model in bf16 by passing `torch_dtype=torch.bfloat16` to the `from_pretrained` function. I have also added an explicit cast with `model = model.to(torch.bfloat16)`.
The exact same code works flawlessly for the 'mini' version of the model, but not for this one.
Any guidance would be greatly appreciated.
Thanks.
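One possible cause is that some floating-point tensor reaching the attention kernel is still float32 even after `model.to(torch.bfloat16)`. This is an assumption and is not confirmed in this thread. Two ways that could happen are weights added after the cast (for example adapter layers) and a layer that upcasts its activations inside `forward()`. The PyTorch sketch below checks for both. It lists fp32 parameters and buffers, and it uses forward hooks to log any module whose output is not fp16/bf16. `find_non_half_tensors`, `add_dtype_hooks` and `Upcast` are hypothetical names for illustration. The toy models stand in for the real model, which you would pass instead before running one training batch.

```python
import torch
import torch.nn as nn

# Dtypes that FlashAttention accepts, per the error message in this thread.
HALF_DTYPES = {torch.float16, torch.bfloat16}


def find_non_half_tensors(model: nn.Module) -> list[tuple[str, torch.dtype]]:
    """Hypothetical helper: list floating-point parameters and buffers that are
    neither fp16 nor bf16. Integer tensors (e.g. index buffers) are ignored."""
    offenders = []
    named = list(model.named_parameters()) + list(model.named_buffers())
    for name, tensor in named:
        if tensor.is_floating_point() and tensor.dtype not in HALF_DTYPES:
            offenders.append((name, tensor.dtype))
    return offenders


def add_dtype_hooks(model: nn.Module, log: list) -> list:
    """Hypothetical helper: append (module_name, dtype) to `log` whenever a
    module returns a floating-point tensor that is neither fp16 nor bf16.
    Returns the hook handles so the caller can remove them afterwards."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            if (isinstance(output, torch.Tensor) and output.is_floating_point()
                    and output.dtype not in HALF_DTYPES):
                log.append((name, output.dtype))
        handles.append(module.register_forward_hook(hook))
    return handles


class Upcast(nn.Module):
    """Toy layer that silently returns float32, standing in for any real
    layer that upcasts its activations."""
    def forward(self, x):
        return x.float()


if __name__ == "__main__":
    # Case 1: weights added after the cast stay in float32.
    model = nn.Sequential(nn.Linear(4, 4)).to(torch.bfloat16)
    model.add_module("adapter", nn.Linear(4, 4))  # never cast to bf16
    print(find_non_half_tensors(model))

    # Case 2: all weights are bf16, but an activation is upcast in forward().
    model = nn.Sequential(nn.Linear(4, 4), Upcast()).to(torch.bfloat16)
    log: list = []
    handles = add_dtype_hooks(model, log)
    model(torch.randn(2, 4, dtype=torch.bfloat16))
    for h in handles:
        h.remove()
    print(log)
```

If either check flags something, one workaround people use is running the training step under `torch.autocast(device_type="cuda", dtype=torch.bfloat16)`. It is untested here whether that resolves this particular model's error.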
See my command in the 4k model thread here:
https://huggingface.co./microsoft/Phi-3-small-8k-instruct/discussions/11#6661de1a62de925acf74516f