Memory (RAM) usage keeps going up every step

Are you storing some variables which were not detached from the computation graph, e.g. the loss in a list?
This would increase the memory usage in each iteration, so I’m a bit unsure why your code seems to run well most of the time.
Also, could you share a (small) code snippet so that we could have a look?
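For reference, here is a minimal sketch of the pattern I have in mind. The toy model, data shapes, and list names are made up for illustration and are not taken from your code. Appending the loss tensor itself keeps a reference to its graph through `grad_fn`, while `loss.item()` (or `loss.detach()`) stores only the value:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup, only here to make the loop runnable.
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

attached_losses = []  # tensors that still reference their computation graph
detached_losses = []  # plain Python floats, no graph reference

for step in range(5):
    x = torch.randn(8, 10)
    y = torch.randn(8, 1)

    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

    # Problematic pattern: the stored tensor keeps its grad_fn alive.
    attached_losses.append(loss)

    # Safer pattern: store only the scalar value.
    detached_losses.append(loss.item())

print(attached_losses[0].grad_fn is not None)  # graph reference still held
print(type(detached_losses[0]))                # plain float
```

If your training loop does something similar to the first `append` (or accumulates the loss with something like `total_loss += loss`), switching to `.item()` or `.detach()` would be the first thing I would try.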