This issue is closed: it turned out to be a bug in how we were initializing the model instances.
We created a model with 15 layers in FP32, each layer a 12K x 12K weight matrix, so each layer's weights take 4 bytes x 12K x 12K = 576 MB. The weights alone are therefore 576 MB * 15 ≈ 8.6 GB, on top of which come activations, input buffers, etc. The model by itself should not fit into an 8 GB GPU.
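For concreteness, a minimal sketch of the setup and the arithmetic (nn.Linear is a stand-in here; the real layers may differ, but the weight shapes are as described):

```python
import torch.nn as nn

# 15 layers, each holding a 12K x 12K FP32 weight matrix
# (nn.Linear is an assumption; only the weight shapes matter here).
model = nn.Sequential(*[nn.Linear(12_000, 12_000) for _ in range(15)])

# Per layer: 12_000 * 12_000 * 4 bytes = 576 MB; 15 layers ≈ 8.6 GB.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"parameter memory: {param_bytes / 1e9:.2f} GB")
```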
Yet PyTorch trains fine on the GPU, using only about 4.7 GB (checked with both nvidia-smi and torch.cuda.memory_allocated()). We also printed model.state_dict() and it shows all the layers exactly as we defined them.
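The checks were along these lines (a sketch; the training loop itself is omitted, and `model` is the module from the sketch above, already moved to the GPU):

```python
import torch

# GPU memory as seen by PyTorch (nvidia-smi reports roughly the same).
print(f"{torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")
print(f"{torch.cuda.memory_reserved() / 1e9:.2f} GB reserved")

# Every layer shows up in the state dict with the expected shape.
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape), tensor.dtype)
```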
We also noticed that the training process's virtual memory on the CPU side spikes to about 50 GB (a sketch of how to check this is below). Any explanations? Thanks.
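For reference, one way to script that host-side check (psutil is an assumption on our part; top/htop show the same number):

```python
import os
import psutil  # assumed available: pip install psutil

# Virtual memory size (VMS) of the current training process;
# this is the figure that spikes to ~50 GB for us.
vms_bytes = psutil.Process(os.getpid()).memory_info().vms
print(f"virtual memory: {vms_bytes / 1e9:.1f} GB")
```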