VRAM usage increase with more gradient accumulation steps

What would cause GPU memory usage to increase with more gradient accumulation steps? Essentially, inside the following loop:

# outer (optimizer-step) loop
    optimizer.zero_grad(set_to_none=True)
    for micro_step in range(gradient_accumulation_steps):
        batch = next(train_data_loader_iterator)
        # skip gradient synchronization on all but the last micro-step
        no_sync_enabled = micro_step < gradient_accumulation_steps - 1
        with no_backward_sync_ctx(enabled=no_sync_enabled):  # no-op for a single device
            loss = model(**batch)
            loss /= gradient_accumulation_steps
            loss.backward()
            total_loss += loss.detach()
    optimizer.step()

I am assuming that changing gradient_accumulation_steps should not have any impact on memory usage: gradient accumulation shouldn't create any new variables, and the per-parameter gradients should be accumulated in place (e.g. grad += new_grad), right?
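For reference, this is the in-place accumulation behaviour I am assuming autograd provides (just a toy check, not my actual training code):

import torch

# Toy check of the accumulation semantics assumed above: repeated backward()
# calls add into the same .grad tensor in place, so (after the first call)
# no extra per-micro-step gradient buffers should be allocated.
p = torch.nn.Parameter(torch.randn(4))
for _ in range(3):
    (p * 2.0).sum().backward()
    print(p.grad.data_ptr(), p.grad)  # same storage pointer every iteration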

In practice, however, whether I train the model on a single GPU, on multiple GPUs with DDP, or with FSDP, I observe higher VRAM usage whenever the number of accumulation steps is greater than 1. Any idea whether this is expected? And if not, how can I debug this “leak”?

Thanks!


In theory, this is expected behavior if we consider where peak memory usage occurs: in a typical forward-then-backward execution, the peak is reached just before the backward pass, because at that point all of the forward activations are still being kept alive for use during the backward pass.
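As a rough way to see those activations staying alive, here is a minimal sketch (model and sizes are made up for illustration): a forward pass under torch.no_grad() saves nothing for backward, whereas a regular forward keeps the intermediate activations resident until backward() consumes them.

import torch

# Minimal sketch (made-up model): compare allocated memory after a no_grad()
# forward, which saves nothing for backward, with a regular forward, which
# keeps the per-layer activations alive until backward() runs.
device = "cuda"
model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)]).to(device)
x = torch.randn(4096, 1024, device=device)

with torch.no_grad():
    model(x)
no_grad_allocated = torch.cuda.memory_allocated(device)

loss = model(x).sum()  # activations for all 8 layers are now kept alive
with_graph_allocated = torch.cuda.memory_allocated(device)

print(f"after no_grad forward: {no_grad_allocated / 2**20:.1f} MiB")
print(f"before backward:       {with_graph_allocated / 2**20:.1f} MiB")
loss.backward()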

At this point, in the regime without gradient accumulation (gradient_accumulation_steps=1), the model parameters have no gradient tensors, because optimizer.zero_grad(set_to_none=True) was called and the gradients are therefore None rather than merely zeroed out. With gradient_accumulation_steps>1, however, those gradient tensors are kept alive at this stage precisely so they can be accumulated into, so the high-water mark of memory usage is higher.

In other words, on the first iteration of for micro_step in range(gradient_accumulation_steps): the gradient fields of the parameters are not yet populated, but on later iterations they are, which raises the maximum total memory reached at that point. In this sense, the very first accumulation is not in-place: it has to allocate the gradient tensors.
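A minimal sketch of that effect (toy model and sizes, not your training loop): the gradients are None at the start of the first micro-step but resident for every later one, and the per-step peak reported by torch.cuda.max_memory_allocated() goes up accordingly between the first and second micro-step.

import torch

# Minimal sketch (toy model): gradients are None on the first micro-step after
# zero_grad(set_to_none=True) but resident on later micro-steps, so the peak
# allocation reached during forward/backward rises after the first micro-step.
device = "cuda"
model = torch.nn.Linear(4096, 4096).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4

optimizer.zero_grad(set_to_none=True)
torch.cuda.reset_peak_memory_stats(device)
for micro_step in range(gradient_accumulation_steps):
    grads_resident = any(p.grad is not None for p in model.parameters())
    x = torch.randn(64, 4096, device=device)
    loss = model(x).sum() / gradient_accumulation_steps
    loss.backward()
    peak_mib = torch.cuda.max_memory_allocated(device) / 2**20
    print(f"micro_step={micro_step} grads_resident={grads_resident} peak={peak_mib:.1f} MiB")
optimizer.step()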

However, if you observe memory usage that keeps increasing as you raise the number of gradient accumulation steps beyond 2, that would be unexpected and could indicate a leak.
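To tell the two cases apart, you could record the peak allocation per optimizer step and compare it across settings, along these lines (a sketch; run_one_optimizer_step is a placeholder for your own accumulation loop plus optimizer.step()):

import torch

def measure_peak_per_step(run_one_optimizer_step, device="cuda"):
    # run_one_optimizer_step is a placeholder for one full optimizer update,
    # i.e. the whole micro-step loop followed by optimizer.step().
    torch.cuda.synchronize(device)
    torch.cuda.reset_peak_memory_stats(device)
    run_one_optimizer_step()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 2**20  # MiB

The expected pattern is a one-time jump going from 1 to 2 accumulation steps (the gradients are now resident at the pre-backward peak) and a roughly flat curve for 2, 4, 8, and so on. A peak that keeps climbing with more steps, or across successive optimizer steps at a fixed setting, suggests something is being kept alive unintentionally (for example accumulating a loss that still holds its graph); torch.cuda.memory_summary() and the memory snapshot tooling around torch.cuda.memory._record_memory_history() can help track down which allocations are responsible.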
