I observed that during the forward pass the memory usage was not just one copy of the weights but roughly three times that (weights + 2x gradient-sized buffers). My conclusion at the time was that this was caused by (1) the optimizer's momentum mechanism, which requires gradient information from the previous training iteration (previous batch) to be kept around, and (2) the possibility that PyTorch allocates all the `.grad` tensors of an `nn.Module` as soon as the instance is created. Based on your answer here, is my conclusion completely wrong? I also believe this has something to do with the ordering of the code I mentioned in another thread yesterday.
Could you help me confirm this, as I might need to report this back to my teammate?
Thanks for reading.