Hit OOM finetuning long-t5-tglobal-xl on input_len=14000

#8
by cyenjoylife - opened

I was trying to fine-tune long-t5-tglobal-xl on BookSum but hit OOM when setting input_len=14000. The training runs on 4 NVIDIA A100-80GB GPUs. I didn't use fp16 due to overflow in the intermediate hidden states.

Could you please share some details on how you solved the memory problem? Did you place the layers on different devices?

Truly appreciate your fantastic work! Looking forward to your reply!

Hi! Thanks for your compliments and interest in the model! Fine-tuning this model is definitely tricky in terms of memory, so you're not alone. For fine-tuning on 16384 tokens, here's what I used:

  1. bf16 - roughly halves memory use vs. fp32 without the overflow issues of fp16. See the efficient training on GPU page if you haven't already. AFAIK I have never been able to tune xl in fp32, but comparisons on smaller models like tglobal-base showed no degradation in validation performance between fp32 and bf16 - you may want to validate this on your own data, but it should hold. IIRC Google themselves typically use bfloat16 for training/tuning.
  2. deepspeed - I used ZeRO-2 with the optimizer states offloaded to the CPU. With the configuration below and a decent amount of RAM (128GB+), I was able to fine-tune on a single A100 80GB, albeit slowly. There is some discussion of (and links to) this on the page linked above. I've included my configuration below; it still works, but it's dated ~August 2022, so I'd recommend checking the deepspeed docs for the latest syntax and/or improvements.
  3. tf32 - magic NVIDIA CUDA-level data type. As far as I know, it doesn't reduce memory requirements in any significant way; it just speeds things up, so it's really a counter to the slowdown that the other memory-saving methods typically introduce. It's also discussed in the linked guide. There's a short sketch right after this list showing how all three of these plug into the Trainer.
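To make points 1-3 concrete, here's a minimal sketch of how they plug into the Hugging Face Seq2SeqTrainingArguments. The file name ds_config_zero2.json and the hyperparameter values are placeholders, not the exact settings I used:

from transformers import Seq2SeqTrainingArguments

# Illustrative values only - adjust for your hardware and data.
training_args = Seq2SeqTrainingArguments(
    output_dir="long-t5-xl-booksum",
    bf16=True,                         # point 1: bf16 instead of fp16/fp32
    tf32=True,                         # point 3: faster matmuls on Ampere, doesn't save memory
    deepspeed="ds_config_zero2.json",  # point 2: path to the ZeRO-2 config shown below
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,    # grow the effective batch size without extra VRAM
    learning_rate=1e-4,
    num_train_epochs=2,
)
# Pass training_args to a Seq2SeqTrainer along with the model and your
# tokenized BookSum splits, then call trainer.train() as usual.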

Things I didn't try for long-context text2text models

  • Deepspeed ZeRO-3 or newer - Deepspeed ZeRO-2 works, and brief tests with ZeRO-3 showed no noticeable improvement over ZeRO-2 while being significantly more annoying to use. This is probably because 3B parameters is relatively small for an 'XL' model, and the real bottleneck is the activation/gradient memory produced by the very long inputs rather than the usual constraints from batch size or parameter count.
  • 8-bit optimizers - keeping the optimizer states in 8-bit (either Adam or Lion) obviously cuts memory requirements quite a bit, which frees headroom for the large activations that long inputs produce. Implementation is not as easy as it sounds, though - for stable training you probably need to keep the embedding layer's optimizer states in 32-bit (or swap in a stable embedding layer). See the bitsandbytes documentation for details; there's a rough sketch after this list.
  • flash attention - TogetherComputer published a LLaMA checkpoint with a context length of 32,768 using flash attention. I suspect Long-T5 could easily handle this, or perhaps even more, if it were implemented. That said, it will probably require rewriting most of the modeling code to use flash attention. If you do this, I am all ears and happy to test it :)
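On the 8-bit optimizer point, here's roughly what the bitsandbytes route looks like - I haven't run this exact snippet for this checkpoint, so treat the model id and learning rate as placeholders and it as a sketch:

import bitsandbytes as bnb
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-xl")

# Keep optimizer state for the embedding layers in 32-bit for training
# stability (per the bitsandbytes docs); everything else stays in 8-bit.
manager = bnb.optim.GlobalOptimManager.get_instance()
for module in model.modules():
    if isinstance(module, nn.Embedding):
        manager.register_module_override(module, "weight", {"optim_bits": 32})

optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-4)
# Hand this to the Trainer via optimizers=(optimizer, None),
# or use it directly in a custom training loop.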

Hope this helps!

deepspeed ZeRO-2 config

As stated above, this is dated ~August 2022, so I'd recommend checking the deepspeed docs for the latest syntax and/or improvements.

{
    "bfloat16": {
        "enabled": "auto"
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "optimizer": {
        "params": {
            "betas": "auto",
            "eps": "auto",
            "lr": "auto",
            "weight_decay": "auto"
        },
        "type": "AdamW"
    },
    "steps_per_print": 4000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "zero_optimization": {
        "allgather_bucket_size": 200000000.0,
        "allgather_partitions": true,
        "contiguous_gradients": true,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "reduce_bucket_size": 200000000.0,
        "reduce_scatter": true,
        "round_robin_gradients": true,
        "stage": 2
    }
}
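When this file is handed to the Trainer (e.g. via deepspeed="ds_config_zero2.json" as in the sketch above), the "auto" values get filled in from your TrainingArguments, and you launch with the deepspeed launcher rather than plain python - something like deepspeed your_train_script.py, where your_train_script.py stands in for whatever training script you're using.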
pszemraj pinned discussion

Greatly appreciate your prompt response! Your reply is highly informative for beginners looking to train an LLM on lengthy sequences. I'll try your configuration and hope it works for me :)

You're welcome! Hope it works out. If you have any more problems, feel free to comment here, or reopen if there is an issue with this checkpoint!

pszemraj changed discussion status to closed
