Yeah, it works now. I'm not sure whether the issue was the varying validation set or that the state_dict was not saved properly.
My guess is that the validation set was varying (i.e., a new set was generated every epoch).
I changed it to a fixed set of samples, tried again, and it worked properly.
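For anyone hitting the same thing, here is a minimal sketch of the difference using only the Python standard library. `make_samples`, `VAL_SET`, and the toy linear model are hypothetical stand-ins, not my actual code. The point is that the fixed validation set is built once, with its own seeded RNG, before training starts. A regenerated set makes the validation score move even when the model has not changed at all.

```python
import random

def make_samples(n, rng):
    """Hypothetical synthetic data generator: n (x, y) pairs with y = 2x + noise."""
    samples = []
    for _ in range(n):
        x = rng.uniform(-1.0, 1.0)
        samples.append((x, 2.0 * x + rng.gauss(0.0, 0.1)))
    return samples

def mse(weight, samples):
    """Mean squared error of the toy model y_hat = weight * x on the given samples."""
    return sum((weight * x - y) ** 2 for x, y in samples) / len(samples)

# Build the validation set ONCE, with a dedicated seeded RNG, before the training loop.
VAL_SET = make_samples(200, random.Random(0))

def validate_fixed(weight):
    # Same samples every epoch: score changes reflect model changes only.
    return mse(weight, VAL_SET)

def validate_regenerated(weight, epoch_rng):
    # Anti-pattern: fresh samples every epoch, so the score also moves with the data.
    return mse(weight, make_samples(200, epoch_rng))

if __name__ == "__main__":
    w = 1.9  # weight held constant to isolate the effect of the data
    print([round(validate_fixed(w), 4) for _ in range(3)])  # three identical values
    rng = random.Random(1)
    print([round(validate_regenerated(w, rng), 4) for _ in range(3)])  # values vary
```

In a PyTorch setup the same idea would apply: create the validation dataset (and its loader) once, outside the epoch loop, so every epoch is scored on the same samples.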
Thank you!