OOM: Multi-label classification

I have a similar but more general question: how do you profile memory in PyTorch?

From my experience, what you see in nvidia-smi is the memory in use at the moment it samples. There might be some part of a computation where memory jumps and is quickly freed, and you might not see that spike in nvidia-smi because of its low update rate.
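One way to catch such short-lived spikes is to ask PyTorch's caching allocator for its peak counter instead of polling nvidia-smi. Below is a minimal sketch, not a definitive recipe: the helper name `peak_cuda_memory` and the toy `step` function are my own placeholders, and it only uses `torch.cuda.reset_peak_memory_stats`, `torch.cuda.memory_allocated` and `torch.cuda.max_memory_allocated`. Keep in mind these counters only cover tensors allocated by PyTorch, so they will usually be lower than what nvidia-smi shows (which also includes the CUDA context and cached-but-unused blocks).

```python
import torch


def peak_cuda_memory(fn, device=None):
    """Run fn() and return (result, stats).

    stats is a dict of byte counts reported by PyTorch's allocator,
    or None when no CUDA device is available.
    Hypothetical helper for illustration, not part of PyTorch.
    """
    if not torch.cuda.is_available():
        return fn(), None

    # Wait for pending kernels so earlier work is not attributed to fn().
    torch.cuda.synchronize(device)
    # Reset the peak counter to the currently allocated amount.
    torch.cuda.reset_peak_memory_stats(device)
    before = torch.cuda.memory_allocated(device)

    result = fn()

    torch.cuda.synchronize(device)
    stats = {
        "before": before,
        "after": torch.cuda.memory_allocated(device),
        # Highest allocation seen since the reset, including temporaries
        # that were already freed by the time fn() returned.
        "peak": torch.cuda.max_memory_allocated(device),
    }
    return result, stats


def step():
    # Toy workload: the matmul output is a temporary that is freed
    # once the function returns.
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1024, 1024, device=dev)
    return (x @ x).sum().item()


if __name__ == "__main__":
    _, stats = peak_cuda_memory(step)
    if stats is None:
        print("No CUDA device found, nothing to measure.")
    else:
        for name, value in stats.items():
            print(f"{name}: {value / 1024**2:.1f} MiB")
```

Wrapping the forward pass, the backward pass and the optimizer step separately should show which stage is responsible for the peak. If "peak" is much higher than "after", the jump is coming from temporaries that nvidia-smi would likely miss.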