Thanks for this. So the backward call is what frees the buffers before the next batch. Since we skip the backward call, memory usage grows by roughly one iteration's worth of buffers, which leads to an OOM if no extra GPU memory is available. Is there a way to manually clear those buffers without calling backward?
From your answer on this post, the only option I see is to get hold of the last tensor in the graph before the error occurred (somehow) and either delete it or call backward on it and then set all gradients to zero.
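
To make the first option concrete, here is a minimal sketch of what I have in mind. It assumes the graph's buffers are only kept alive through the Python references to `output` and `loss`, so dropping those references should let autograd release them. The model, the batches, and the simulated error are all made up for illustration and are not from my actual code.

```python
import torch
import torch.nn as nn

# Hypothetical toy setup, only for illustration.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [torch.randn(8, 4) for _ in range(3)]

skipped = []
for i, batch in enumerate(batches):
    try:
        output = model(batch)
        loss = output.pow(2).mean()
        if i == 1:
            # Simulated failure between forward and backward.
            raise RuntimeError("simulated error before backward")
        loss.backward()
        optimizer.step()
    except RuntimeError:
        skipped.append(i)
    finally:
        # Drop every reference to the graph's outputs so the saved
        # buffers can be freed even though backward was skipped.
        output = loss = None
        optimizer.zero_grad()
```

This only works if nothing else (a logged loss tensor, a stored exception traceback, etc.) still holds a reference to a tensor from that forward pass.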