DDP no_sync causes increased memory usage

My application uses AMP, DDP, and gradient accumulation. I’m trying to use no_sync() to optimize the training loop. However, when I use no_sync(), it adds an extra ~2 GB of VRAM usage and ultimately causes an OOM after the first optimizer step. Without no_sync() it runs fine (albeit inefficiently). Here is a code snippet:

#with nullcontext():
with self.unet.no_sync() if not is_last_device_step else nullcontext():
	# Forward pass
	loss = self.get_model_pred(batch)
	loss = loss / self.gradient_accumulation_steps

	# Backward pass
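For reference, here is a minimal self-contained sketch of the pattern described above: gradient accumulation with DDP's no_sync() and AMP's GradScaler, shrunk to a toy model so it runs in a single CPU process with the gloo backend. The model, the data, and the helpers `setup_single_process_group`, `accumulate_and_step` and `demo` are hypothetical stand-ins, not the code from this post. AMP is wired in but disabled, because float16 loss scaling needs CUDA.

```python
import os
import tempfile
from contextlib import nullcontext

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_single_process_group() -> None:
    # One-process "distributed" group, like the single-GPU test in the post.
    # gloo + a file-based rendezvous lets the sketch run on a CPU-only machine.
    if not dist.is_initialized():
        init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
        dist.init_process_group(
            "gloo", init_method=f"file://{init_file}", rank=0, world_size=1
        )


def accumulate_and_step(ddp_model, optimizer, scaler, micro_batches, use_amp=False):
    """One optimizer step over all micro-batches (hypothetical helper)."""
    accum_steps = len(micro_batches)
    optimizer.zero_grad(set_to_none=True)
    total = 0.0
    for i, (x, y) in enumerate(micro_batches):
        is_last = i == accum_steps - 1
        # Skip the gradient all-reduce on every micro-batch except the last one.
        ctx = nullcontext() if is_last else ddp_model.no_sync()
        with ctx:
            # Per the no_sync() docs, the forward pass must be inside the
            # context (DDP reads the sync flag during forward); backward is
            # kept inside too, matching the documented usage example.
            with torch.autocast(device_type=x.device.type, enabled=use_amp):
                loss = nn.functional.mse_loss(ddp_model(x), y) / accum_steps
            scaler.scale(loss).backward()
        total += loss.item()
    scaler.step(optimizer)  # unscales grads (if enabled), then optimizer.step()
    scaler.update()
    return total  # mean loss over the micro-batches


def demo() -> float:
    setup_single_process_group()
    torch.manual_seed(0)
    model = DDP(nn.Linear(8, 1))  # CPU module, so no device_ids
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    # AMP scaler kept in the loop but disabled so this runs without CUDA.
    scaler = torch.cuda.amp.GradScaler(enabled=False)
    batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(4)]
    loss = accumulate_and_step(model, opt, scaler, batches)
    dist.destroy_process_group()
    return loss


if __name__ == "__main__":
    print(f"mean micro-batch loss: {demo():.4f}")
```

On a GPU, one would typically use the nccl backend, wrap the model as `DDP(model.cuda(), device_ids=[rank])`, and pass `use_amp=True` with an enabled GradScaler. That variant is not shown here.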

When I use only nullcontext(), memory usage just before the optimizer step is 18186 MB as reported by nvidia-smi. With no_sync(), it is 20262 MB.
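nvidia-smi reports everything the process holds, including the CUDA context and memory that PyTorch's caching allocator has reserved but not handed out to tensors. As a result, it can overstate what tensors actually use. A small sketch for logging PyTorch's own allocator counters at the same point in both runs follows. The helper name `cuda_memory_report` is hypothetical, while the `torch.cuda` functions it calls are standard.

```python
import torch


def cuda_memory_report(device=None):
    """Allocator statistics in MiB, or None when CUDA is unavailable (hypothetical helper)."""
    if not torch.cuda.is_available():
        return None
    mib = 1024 ** 2
    return {
        # Memory currently occupied by live tensors.
        "allocated": torch.cuda.memory_allocated(device) / mib,
        # Memory held by the caching allocator (closer to what nvidia-smi sees).
        "reserved": torch.cuda.memory_reserved(device) / mib,
        # High-water mark of tensor memory since start or last reset.
        "peak_allocated": torch.cuda.max_memory_allocated(device) / mib,
    }
```

Calling it right before `optimizer.step()` in both configurations, optionally after `torch.cuda.reset_peak_memory_stats()` at the start of the accumulation window, would show whether the ~2 GB gap is in live tensors or only in cached, reserved memory.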

My understanding is that the only difference no_sync() makes is that it skips reducing (all-reducing) the gradients, so I don’t see why it would increase memory usage. Is it keeping a second copy of the gradients for some reason?
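One way to frame the "second copy" question is to compare the ~2 GB difference with the size of one full set of gradients. The rough sketch below sums the bytes held by every populated `.grad` tensor (`grad_memory_mib` is a hypothetical helper). Calling it on the model just before the optimizer step gives one copy's worth. It counts only `.grad` tensors, not any internal buffers DDP may keep, so it is a yardstick rather than a diagnosis.

```python
import torch


def grad_memory_mib(model: torch.nn.Module) -> float:
    """Total size of all populated .grad tensors, in MiB (hypothetical helper)."""
    total_bytes = 0
    for p in model.parameters():
        if p.grad is not None:
            total_bytes += p.grad.numel() * p.grad.element_size()
    return total_bytes / 1024 ** 2
```

For example, an fp32 `torch.nn.Linear(1024, 256)` has 1024 × 256 + 256 = 262,400 parameters, so after one backward pass its gradients occupy 262,400 × 4 bytes ≈ 1.0 MiB.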

For testing purposes, I’m currently running DDP on only a single GPU.

PyTorch 2.0.1, NVIDIA RTX 3090, driver 525.105.17, CUDA 12.0.