Super slow loading compared to other (even bigger) models
Loading the model with the following code (on consecutive runs, i.e. the weights are already downloaded):
generate = InstructionTextGenerationPipeline(
    "mosaicml/mpt-7b-instruct",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
takes way, way more time than loading other, even bigger models (like databricks/dolly-v2-12b) on the same hardware.
I observe that one thread spins a single core at 100% for a couple of minutes.
Can confirm this.
You can try setting "use_cache": true in config.json.
That's worked for me in other models...
I haven't tried this model yet. YMMV etc.
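If you'd rather not edit config.json on disk, here's the same idea as a quick sketch via AutoConfig (assuming the plain AutoModelForCausalLM path rather than the pipeline wrapper above):

from transformers import AutoConfig, AutoModelForCausalLM

# Load the repo's config, flip the flag, then pass the modified config in.
config = AutoConfig.from_pretrained("mosaicml/mpt-7b-instruct", trust_remote_code=True)
config.use_cache = True
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-instruct",
    config=config,
    trust_remote_code=True,
)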
Yup
It seems to be the dtype casting, as loading in 32-bit is way faster. Apparently the conversion happens in just one thread and allocates memory very inefficiently.
Both fp16 and bf16 are equally affected.
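If someone wants to sanity-check that, here is a rough sketch that times just the cast in isolation (this assumes the bottleneck really is the fp32 -> bf16 conversion itself, and it needs several GB of free RAM):

import time
import torch

# Roughly 1B parameters' worth of fp32 weights (~4 GB); scale n down to fit your RAM.
n = 1_000_000_000
x = torch.empty(n, dtype=torch.float32)

t0 = time.time()
y = x.to(torch.bfloat16)  # approximates the per-tensor cast done at load time
print(f"fp32 -> bf16 cast of {n} elements took {time.time() - t0:.1f}s")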
Hi, could someone whose system is seeing a slowdown confirm whether the issue affects other HF models when using torch_dtype=bfloat16 or float16? Basically I would like to know if this is correct:
- MPT-7B, torch_dtype=torch.float32 (fast)
- MPT-7B, torch_dtype=torch.bfloat16 (slow)
- OPT-6.7B, torch_dtype=torch.float32 (fast)
- OPT-6.7B, torch_dtype=torch.float16 (slow)
If this is true, then I will chalk it up to how torch / HF handle the dtype conversion, as we aren't doing anything special with MPT other than saving the weights in BF16 (similar to OPT saving the weights in FP16).
This is the actual execution order; I made sure the models were downloaded beforehand.
import torch
from transformers import AutoModelForCausalLM

def benchmark(model_name, dtype):
    # Load the model in the requested dtype; we only care about load time here.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        trust_remote_code=True,
        use_auth_token=False,
    )
%%time
benchmark("facebook/opt-6.7b", torch.float32)
CPU times: user 55.7 s, sys: 31.4 s, total: 1min 27s
Wall time: 1min
%%time
benchmark("facebook/opt-6.7b", torch.bfloat16)
CPU times: user 1min 6s, sys: 17.9 s, total: 1min 24s
Wall time: 1min 8s
###
%%time
benchmark("mosaicml/mpt-7b-instruct", torch.float32)
CPU times: user 1min 54s, sys: 25 s, total: 2min 19s
Wall time: 1min 31s
%%time
benchmark("mosaicml/mpt-7b-instruct", torch.bfloat16)
CPU times: user 4min 13s, sys: 26.8 s, total: 4min 40s
Wall time: 3min 50s
@abhi-mosaic Are you looking into this? Loading mpt-7b in bfloat16 takes almost 3x as long as loading it in float32, and almost 4x as long as loading opt-6.7b.
The fix for me was setting low_cpu_mem_usage=True; this would normally be the default if device auto-mapping were turned on. It took the load time from 3 minutes to 6 seconds.
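i.e. something along these lines (just a sketch of the workaround; depending on your transformers version the flag may require accelerate to be installed):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-instruct",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,   # load weights directly instead of materializing a random-init model first
    trust_remote_code=True,
)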
It does the trick for practical purposes (now the model loads in mere seconds), but the mystery remains unsolved :)
I don't have the traces in front of me, but I'm fairly sure it was doing the fp32 -> bf16 cast in Python on CPU (twice). I think that flag makes it just load the final weights without first writing out the randomly initialized empty model.
Not sure why more people aren't hitting this; I suspect having accelerate installed would trigger a similar effect. I was running in a container, so I may not have had it installed.
Just wanted to note that device_map support and faster KV caching have been added in this PR: https://huggingface.co./mosaicml/mpt-7b-instruct/discussions/41
I think now, finally, if you load the model with device_map="auto" on a CPU machine, it should go really fast because it avoids the random weight init.
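For completeness, a minimal sketch of that loading path (device_map="auto" requires accelerate and turns on low_cpu_mem_usage automatically):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",        # dispatches weights across available devices; implies low_cpu_mem_usage=True
    trust_remote_code=True,
)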