How to prevent "CUDA Out of Memory" when using a large batch size?

Hi there!

I have a question that might be simple for you: when trying to use a very large batch size (say 5000 for CIFAR-10), what can I set in PyTorch to prevent the GPU from running out of memory?
For example, using a dtype that takes less memory, or setting a smaller num_workers?


You could use automatic mixed precision to run in float16 where applicable, or apply torch.utils.checkpoint to trade compute for memory.
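A minimal sketch of the mixed-precision approach (the toy model, sizes, and learning rate are made up for illustration; on a machine without a GPU this falls back to bfloat16 autocast on the CPU and the GradScaler becomes a no-op):

```python
import torch
import torch.nn as nn

# Hypothetical toy model and data, just to show the training-step pattern.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# GradScaler guards against float16 gradient underflow; disabled on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

data = torch.randn(64, 32, device=device)
target = torch.randint(0, 10, (64,), device=device)

optimizer.zero_grad()
# autocast runs eligible ops in a lower-precision dtype, which roughly
# halves the memory used by activations.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16
with torch.autocast(device_type=device, dtype=amp_dtype):
    loss = criterion(model(data), target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

Since autocast only wraps the forward pass and loss computation, the rest of the training loop stays unchanged.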

Changing the data type manually to float16 or another low-precision format might also work, but you would have to take care of the numerical stability of all operations in your model yourself, since the values might over-/underflow more easily.
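The torch.utils.checkpoint suggestion can be sketched as follows (the depth, layer sizes, and segment count are arbitrary placeholders): intermediate activations inside each segment are dropped after the forward pass and recomputed during backward, so peak memory shrinks at the cost of extra compute.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Hypothetical deep stack of identical blocks.
model = nn.Sequential(
    *[nn.Sequential(nn.Linear(128, 128), nn.ReLU()) for _ in range(8)]
)
x = torch.randn(16, 128, requires_grad=True)

# Split the stack into 2 segments; only the segment-boundary activations
# are kept, everything in between is recomputed in backward.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()
```

For non-sequential models you would wrap individual submodule calls with torch.utils.checkpoint.checkpoint instead.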

The number of workers in a DataLoader doesn't change the GPU memory usage; it affects host (CPU) RAM and data-loading speed instead.
