In case someone is interested, the following helped me:
I had to detach, and preferably also empty, all tensors used inside the for loops. I am not adding the code here, but for the for loops inside the epoch loop, I detach and zero the tensors before they are computed and saved.
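
A rough sketch of what this could look like, assuming the framework is PyTorch (the post does not say so explicitly). The function `train`, the toy `Linear` model, and the names `running_loss` and `saved_outputs` are hypothetical and only for illustration; this is not the original code. The idea is that anything accumulated or stored across iterations is re-zeroed before use and detached before saving, so it does not keep the autograd graph of each step alive.

```python
import torch


def train(num_epochs=2, steps_per_epoch=3):
    # Hypothetical toy setup; the real model and data are not shown in the post.
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    saved_losses = []   # one Python float per epoch
    saved_outputs = []  # detached copies of per-step outputs

    for epoch in range(num_epochs):
        # Start every epoch from a fresh zero tensor that carries no graph.
        running_loss = torch.zeros(())

        for step in range(steps_per_epoch):
            x = torch.randn(8, 4)
            y = torch.randn(8, 1)

            optimizer.zero_grad()
            out = model(x)
            loss = torch.nn.functional.mse_loss(out, y)
            loss.backward()
            optimizer.step()

            # Detach before accumulating or saving, so the stored tensors
            # do not hold a reference to this step's computation graph.
            running_loss += loss.detach()
            saved_outputs.append(out.detach().cpu())

        # Store a plain float rather than a tensor.
        saved_losses.append(running_loss.item() / steps_per_epoch)

    return saved_losses, saved_outputs
```

Accumulating `loss` directly (without `.detach()` or `.item()`) is a common way for memory to grow from one iteration to the next, because each stored value then keeps its whole graph reachable.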