Training two different CNN architectures (one small, one big) each takes 20GB of VRAM while training.

One is a super simple ResNet autoencoder with an (8, 3, 256, 256) input; the whole model is 1.7MB as a .pt file.
The other is a UNet with the same batch/tensor size and a lot more layers, 17MB in the .pt file.
Both use depthwise convolutions instead of plain conv2d(), plus GELU and layer norm.

When training, both consume 20GB of VRAM as soon as the model runs for the first time. Thankfully I have an A5000 with 24GB, but this seems excessive, and also odd given the difference in model size.

I’ve stepped through and noticed that for the ResNet AE, the encoder half of the network burns about 8GB on first invocation, and the decoder half another 8-9GB. After that, memory doesn’t really go up or down; it stays pegged at 20GB. I’ve reduced the worker count, since I’ve seen that affect memory before, but no luck here.
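For anyone curious, something like this is how one can snapshot per-stage usage (the `model`/`batch`/`criterion` names in the comments are placeholders, not my actual code):

```python
import torch

def cuda_mib():
    # Bytes currently held by tensors on the GPU, in MiB (0 without CUDA).
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.memory_allocated() / 2**20

# Sketch of usage inside a training step:
# print(f"before forward: {cuda_mib():.0f} MiB")
# out = model(batch)
# print(f"after forward:  {cuda_mib():.0f} MiB")  # weights + saved activations
# loss = criterion(out, batch)
# loss.backward()
# print(f"after backward: {cuda_mib():.0f} MiB")  # activations freed, grads added
```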

I’ve tried the autocast/GradScaler route and it doesn’t do anything for memory or speed.
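For context, by that I mean the standard AMP pattern, roughly like this (toy model standing in for the real AE/UNet; AMP is simply disabled when no GPU is present):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 16).to(device)  # toy stand-in model
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
x = torch.randn(8, 16, device=device)

for _ in range(2):
    opt.zero_grad(set_to_none=True)  # set_to_none frees grad memory entirely
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
```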

I understand it has to store gradients per parameter, and that the tensors flow through a pipeline, but this still seems excessive. I also see examples from years ago with very large batch sizes on much less capable GPUs, so I’m wondering if I’ve done something wrong somewhere.
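To put rough numbers on the per-parameter side: fp32 weights, their gradients, and Adam’s two state buffers come to about four copies of the parameters, which for models this small is negligible. A quick back-of-the-envelope sketch (toy layers, not my exact architecture):

```python
import torch

# Toy stand-in; the real AE is ~1.7MB of parameters.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 3, padding=1),
    torch.nn.GELU(),
    torch.nn.Conv2d(64, 3, 3, padding=1),
)
n_params = sum(p.numel() for p in model.parameters())
# fp32 weights + grads + Adam exp_avg + exp_avg_sq ~= 4 copies, 4 bytes each
param_side_mib = n_params * 4 * 4 / 2**20
print(f"{n_params} params -> ~{param_side_mib:.2f} MiB for weights+grads+Adam state")
```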

I’ve also seen advice here on the forums that the GC might be hanging onto tensors longer than it should, but after adding a bunch of detach() calls and even some del statements, there was no impact on memory.
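For what it’s worth, one classic way Python "hangs onto" GPU memory is accumulating the loss tensor itself across steps, which builds an ever-growing autograd graph; converting to a float first avoids it. This is a general pitfall, not necessarily my bug:

```python
import torch

model = torch.nn.Linear(8, 8)  # toy stand-in model
x = torch.randn(4, 8)

total = 0.0
for _ in range(3):
    loss = model(x).pow(2).mean()
    loss.backward()
    # total += loss        # BAD: builds an ever-growing autograd graph on `total`,
    #                      # retaining references across iterations
    total += loss.item()   # GOOD: plain Python float; the graph can be freed
```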

Asking Bard, ChatGPT, and general Googling have all been fruitless. I’ve tried everything suggested and nothing has affected memory usage.

I also tried a batch size of 1: my simple ResNet eats about 2GB of VRAM, and it does seem to scale linearly from there, ending up around 18-19GB back at batch size 8.

So, looking for pointers on how to rein in the memory usage, or at least someone to tell me “seems about the same on my computer”.
Thanks Pytorch forums!

You are not accounting for the intermediate activations, which can use significantly more memory than the parameters, and which fits what you are seeing:

This post estimates the memory usage from parameters and forward activations.
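As a rough sketch, you can also sum activation sizes yourself with forward hooks on leaf modules (this only counts module outputs once, so treat it as an estimate; the layer stack below is a toy example at your tensor size, not your model):

```python
import torch

def activation_mib(model, inp):
    """Rough activation footprint of one forward pass: sum of leaf-module output sizes."""
    total = 0
    hooks = []

    def hook(mod, args, out):
        nonlocal total
        if isinstance(out, torch.Tensor):
            total += out.numel() * out.element_size()

    for m in model.modules():
        if len(list(m.children())) == 0:  # leaf modules only, to avoid double counting
            hooks.append(m.register_forward_hook(hook))
    with torch.no_grad():
        model(inp)
    for h in hooks:
        h.remove()
    return total / 2**20

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 3, padding=1),
    torch.nn.GELU(),
    torch.nn.Conv2d(64, 64, 3, padding=1),
)
print(f"~{activation_mib(model, torch.randn(8, 3, 256, 256)):.0f} MiB of activations")
```

Even these three toy layers already produce hundreds of MiB of activations at an (8, 3, 256, 256) input, so a deep network at that resolution adds up quickly.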

I’ll go look at the post. Thanks.

Hmm, that OP has a much larger network (a ResNet-50; mine is more of a "ResNet-9"), uses a batch size of 96, and only gets to 10GB. Meanwhile, I can barely fit 8. So I guess something is amiss.

I’ll try adding the instrumentation from your example. I also see the referenced posts and will look at those too. I appreciate the response.

Looks like there is also a large difference in memory usage between Linux and Windows. I was using PyTorch 2.0 on Windows. I just tried WSL with Ubuntu 22 and PyTorch 2.0, and also used torch.compile() with the default inductor backend. It used much less memory and was almost twice as fast per epoch, and I could easily increase the batch size to 16.

I still haven’t figured out where all the memory disappears to, but training in WSL isn’t a huge burden, so at least I can leverage that.