Explanation of exact effect of AMP on memory usage

I’m writing a blog post breaking down PyTorch memory usage at each step of training, including special cases like training with mixed precision. I have come up with an accurate memory model, but it fails in the case of AMP. Here is a link to a Colab notebook with the memory model, where you can replicate everything I describe in this post.

From what I understand of how mixed precision works, it should save exactly half of the memory that the forward pass would normally take (see this and this and this - the last one uses Apex rather than the native PyTorch AMP, but it is the first empirical evidence I have seen that AMP halves forward-pass memory). This is because the forward activations are stored in fp16, while the gradients and model weights remain in fp32. However, in the experiments I have run, I have seen erratic memory usage with AMP. I would like to understand what is going on so that I can consistently get the memory savings I expect.
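To make the "half the memory" expectation concrete, here is the byte arithmetic behind it - a minimal illustration using Python's struct module (the layer shape is a made-up example, not from my model):

```python
import struct

# Per-element sizes for IEEE 754 single and half precision.
FP32_BYTES = struct.calcsize("f")  # 4 bytes
FP16_BYTES = struct.calcsize("e")  # 2 bytes

def activation_bytes(batch_size: int, features: int, elem_bytes: int) -> int:
    """Bytes needed to store one activation tensor of shape (batch_size, features)."""
    return batch_size * features * elem_bytes

# Hypothetical layer: 1000 output features, batch size 64.
fp32 = activation_bytes(64, 1000, FP32_BYTES)
fp16 = activation_bytes(64, 1000, FP16_BYTES)
# Storing the same activations in fp16 takes exactly half the bytes,
# which is where the expected 50% forward-pass saving comes from.
```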

I have written code to estimate maximum memory usage based on the model size, the input size, the batch size, the optimizer type, and whether AMP is being used. The first function estimates the memory and the second function measures the actual memory usage. The estimator matches the true memory usage exactly, except when use_amp=True.
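The actual estimator is in the notebook; as a rough sketch of the idea, assuming fp32 weights, fp32 gradients, one fp32 momentum buffer per parameter, and activations that AMP is expected to store in fp16 (the function and parameter names here are hypothetical, not the notebook's code):

```python
def estimate_max_memory(n_params: int, n_activation_elems: int,
                        momentum: bool = True, use_amp: bool = False) -> int:
    """Rough upper bound on peak training memory, in bytes."""
    weights = 4 * n_params                  # fp32 parameters
    grads = 4 * n_params                    # fp32 gradients
    opt = 4 * n_params if momentum else 0   # one fp32 momentum buffer per param
    act_bytes = 2 if use_amp else 4         # AMP should store activations in fp16
    activations = act_bytes * n_activation_elems
    return weights + grads + opt + activations

# With AMP, only the activation term shrinks; the fixed fp32 costs do not.
no_amp = estimate_max_memory(1_000_000, 10_000_000, use_amp=False)
amp = estimate_max_memory(1_000_000, 10_000_000, use_amp=True)
```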

One note: there are two different measures of memory usage relevant here, and they don’t always agree with each other:
1 - The memory reported by torch.cuda.memory_allocated.
2 - The memory shown by nvidia-smi after the run.

Here are six runs with slightly different parameters, with different outcomes, and my observations.

Run 1:
Function: test_memory(batch_size=64, use_amp=False)
Estimated max memory: 38121984 (this includes the model, gradients, optimizer momentum values, and forward pass).
True max torch.cuda.memory_allocated: 38121984 (exactly correct)
nvidia-smi memory after run: 501MiB
Observation: the prediction works as expected without use_amp

Run 2:
Function: test_memory(batch_size=64, use_amp=True)
Estimated max memory: 35536128 (with the forward-pass memory estimate halved).
True max torch.cuda.memory_allocated: 39634432
nvidia-smi memory after run: 503MiB
Observation: max allocated memory is greater than without use_amp. This is the opposite of what I would expect.

Run 3:
Function: test_memory(batch_size=128000, use_amp=False)
Estimated max memory: 10375350784
True max torch.cuda.memory_allocated: 10375350784 (exactly correct)
nvidia-smi memory after run: 10667MiB / 11441MiB
Observation: the prediction works as expected without use_amp. The only difference is that with this much larger batch, according to nvidia-smi, CUDA appears to cache much more data.

Run 4:
Function: test_memory(batch_size=140000, use_amp=True)
Estimated max memory: 5751884032
True max torch.cuda.memory_allocated: 5723155456 (very similar to estimation which uses 1/2 multiplier for forward pass memory).
nvidia-smi memory after run: 11429MiB / 11441MiB
Observation: torch.cuda.memory_allocated is pretty close to the prediction. But the memory reflected in nvidia-smi is about double the memory shown by torch.cuda.memory_allocated.

Run 5:
Function: test_memory(batch_size=140000, use_amp=False)
Estimated max memory: 11470817792 (more than the available memory).
True max torch.cuda.memory_allocated: [Memory Error]
nvidia-smi memory after run: 11427MiB / 11441MiB
Observation: as expected, the GPU runs out of memory. When compared with run 4, this shows that use_amp does save a small amount of memory.

Run 6:
Function: test_memory(batch_size=150000, use_amp=True)
Estimated max memory: [Memory Error]
True max torch.cuda.memory_allocated: [Memory Error]
nvidia-smi memory after run: 11427MiB / 11441MiB
Observation: in practice, we don’t see the memory savings that AMP should enable - with a 50% saving on the forward pass, we should be able to roughly double the non-AMP batch size from run 3 (128000) to about 256000, yet batch_size=150000 already runs out of memory.
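For reference, the ~256000 figure follows from simple arithmetic: fit a fixed-plus-linear memory model to the measured totals from runs 1 and 3, then halve only the per-sample (activation) term. This is approximate, since it assumes memory scales exactly linearly with batch size (the estimates for runs 4 and 5 suggest some extra per-batch terms):

```python
# Measured torch.cuda.memory_allocated maxima, without AMP (bytes).
total_b64 = 38_121_984        # batch_size = 64 (run 1)
total_b128k = 10_375_350_784  # batch_size = 128000 (run 3)

# Solve memory = fixed + per_sample * batch_size from the two data points.
per_sample = (total_b128k - total_b64) // (128_000 - 64)  # fp32 activation bytes/sample
fixed = total_b64 - 64 * per_sample                       # weights + grads + optimizer

# If AMP halves only the per-sample term, the largest batch that fits
# in the card's 11441 MiB should roughly double.
capacity = 11_441 * 1024 * 1024
max_batch_fp32 = (capacity - fixed) // per_sample
max_batch_amp = (capacity - fixed) // (per_sample // 2)
```

Because the fixed costs (model, gradients, optimizer state) are tiny compared to the activations here, halving the activation term should nearly double the feasible batch size - which is why batch_size=150000 failing with AMP is surprising.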

The erratic behavior can be summarized as follows:
1 - For small batches without AMP, it looks like very little memory is cached, but for large batches without AMP, almost all memory is cached (based on nvidia-smi run 1 vs run 3).
2 - For small batches with AMP, torch.cuda.memory_allocated shows an increase in memory usage, but for large batches with AMP, torch.cuda.memory_allocated shows the expected memory savings of 1/2 the forward pass (based on torch.cuda.memory_allocated run 2 vs run 4).
3 - For large batches, torch.cuda.memory_allocated shows the expected 1/2 memory savings, but nvidia-smi shows no savings at all (based on run 4’s torch.cuda.memory_allocated and nvidia-smi). My theory is that the forward pass is first computed in fp32 and then cast to fp16, leaving a large amount of cached memory behind. I looked at the source code for autocast, but it was a bit opaque; I don’t see where it explicitly casts to fp16.

All of these results can be replicated with the shared Colab notebook.

How should I expect AMP to affect memory usage? Does CUDA caching restrict available memory? I’m particularly curious about run 4, where torch.cuda.memory_allocated shows the expected AMP memory savings, but nvidia-smi shows no savings and I get a memory error. Most importantly, is it possible in PyTorch to save the expected 50% of memory on the forward pass when using AMP, like you can in Apex?

@ptrblck calling in the big guns on this one

@ptrblck do you have any insight into why I am unable to achieve the expected 50% memory savings with AMP? Or a resource you could point me to where I can read up on why AMP doesn’t provide the expected memory savings in PyTorch?