I ran the code with model.eval(). After the last batch norm layer, all the values become NaN. However, when I set model.train(), the code works fine.
Further, I saw that the weight and bias parameters of the batch norm layers are going to NaN. This seems like quite weird behaviour.
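Since eval mode uses the running statistics (running_mean/running_var) instead of the per-batch statistics, those buffers are also worth inspecting alongside the weight and bias. Here is a minimal sketch of such a check, assuming a PyTorch model; the `Sequential` below is just a placeholder for the actual network:

```python
import torch
import torch.nn as nn

# Placeholder model for illustration; substitute your own network here.
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.ReLU())

# Inspect every batch-norm layer's learnable parameters and running statistics.
for name, module in model.named_modules():
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        print(f"{name}: "
              f"weight NaN={torch.isnan(module.weight).any().item()}, "
              f"bias NaN={torch.isnan(module.bias).any().item()}, "
              f"running_mean NaN={torch.isnan(module.running_mean).any().item()}, "
              f"running_var NaN={torch.isnan(module.running_var).any().item()}")
```

If the running buffers are NaN but the batch statistics are fine, that would explain why train() works while eval() does not.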
In the following threads, people are trying to do the exact same thing. I followed all the solutions mentioned there, but still no luck.