Rank 0 always running out of GPU memory

Following the guide here, I set up distributed training on my devbox, which has 4 GPUs (12 GB each). However, after a few steps the memory on GPU0 always maxes out and the corresponding rank dies, while the remaining ranks keep running. This happens even with very small batch sizes (e.g., 16). Any ideas on what could be happening? The model seems small enough to fit comfortably in 12 GB, and the input tensors are not very large either.

Happy to provide more specific information. Any debugging tips are welcome.