How to use Gradient Checkpointing with Hugging Face models

nlp
code
Author

Hamish Dickson

Published

March 23, 2024

Imagine you are GPU poor and would like to train a huge model. How do you do it?

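Well, your first approach is probably to use gradient accumulation: run several small batches, each with its own forward and backward pass, and let the gradients add up before taking a single optimizer step. When you average the losses across those small batches you can get very similar results to a bigger batch.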
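As a rough illustration of gradient accumulation, here is a minimal PyTorch sketch. The toy linear model, the random data and the names `micro_batches` and `accumulation_steps` are all made up for the example; only the pattern matters.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# Toy data: 8 micro-batches of 2 examples each (hypothetical sizes).
micro_batches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(8)]
accumulation_steps = 4  # effective batch size = 2 * 4 = 8

optimizer_steps = 0
optimizer.zero_grad()
for step, (x, y) in enumerate(micro_batches, start=1):
    loss = loss_fn(model(x), y)
    # Scale the loss so the accumulated gradient is the average over micro-batches.
    (loss / accumulation_steps).backward()
    if step % accumulation_steps == 0:
        # Only now update the weights, using the accumulated gradients.
        optimizer.step()
        optimizer.zero_grad()
        optimizer_steps += 1
```

Each micro-batch still needs to fit on the GPU on its own, which is exactly the limitation the rest of this post is about.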

But … what if you can’t even fit a single forward pass? What if you just don’t have enough GPU VRAM?

Gradient checkpointing is an easy way to get around this. Here is what you need to do: when you declare your model, just add model.gradient_checkpointing_enable()

import transformers
# note: you need this import, it's missing from almost all documentation
import torch.utils.checkpoint

model = transformers.AutoModel.from_pretrained("my_huge_model")
model.gradient_checkpointing_enable()

And that’s it, you get much more freedom to train your model.
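If you want to check that the flag actually took effect, something like the following should work. This is only a sketch: it builds a tiny, randomly initialised BERT from a config (the sizes are arbitrary and chosen so nothing has to be downloaded) and assumes a reasonably recent transformers version where `is_gradient_checkpointing` is available on the model.

```python
import torch
import torch.utils.checkpoint  # kept from the snippet above
import transformers

# Tiny random model, purely for illustration (all sizes are arbitrary).
config = transformers.BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = transformers.BertModel(config)
model.gradient_checkpointing_enable()
model.train()  # checkpointing is only applied in training mode

# One forward and backward pass on random token ids.
input_ids = torch.randint(0, config.vocab_size, (2, 8))
output = model(input_ids=input_ids)
output.last_hidden_state.sum().backward()

checkpointing_on = model.is_gradient_checkpointing
has_grads = model.embeddings.word_embeddings.weight.grad is not None
```

Swap the config-built model for your own `from_pretrained` call and the same two checks apply.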

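So what is this doing? Well, the short answer is we are trading speed for memory. Normally every intermediate activation from the forward pass is kept on the GPU until the backward pass needs it. With gradient checkpointing we keep only a few of them (the checkpoints), throw the rest away, and recompute them during the backward pass. Nothing is moved off the GPU; we just pay for some extra forward computation.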
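In plain PyTorch the tool for this is `torch.utils.checkpoint.checkpoint`: a block wrapped in it does not keep its intermediate activations and reruns its forward pass during backward instead. The sketch below uses a made-up two-layer block and checks that the gradients come out the same either way, so the only thing that changes is memory and compute. It assumes a PyTorch version that accepts the `use_reentrant` argument.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# A small stand-in for one block of a much bigger model.
block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Ordinary forward/backward: intermediate activations are stored.
block(x).sum().backward()
grad_plain = x.grad.clone()
x.grad = None
block.zero_grad()

# Checkpointed forward/backward: intermediates inside `block` are
# dropped after the forward pass and recomputed during backward.
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_checkpointed = x.grad.clone()

same_gradients = torch.allclose(grad_plain, grad_checkpointed)
```

On a block this small the saving is negligible; it only starts to matter when the wrapped blocks are large transformer layers.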

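You should play about with batch sizes here: the memory you free up lets you fit a bigger batch, which can win back some of the speed you gave up. Try a few and find the balance that works on your GPU.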