I would like to have a method to compute/estimate the amount of memory that would be used by the network during one epoch. I am not considering the size of the inputs here, only the memory taken by the network and its internal parameters.
I know the network computes the gradients of the parameters during the backward pass, and I think it is straightforward to compute the memory they take, since each gradient tensor has the same shape and dtype as its parameter.
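For that part, here is a minimal sketch of what I have in mind, assuming the network is a PyTorch `nn.Module` (the framework and the model used here are just placeholders for illustration):

```python
import torch.nn as nn

def param_and_grad_bytes(model: nn.Module) -> int:
    """Estimate bytes used by parameters plus their gradients.

    Each gradient tensor has the same number of elements and the same
    dtype as its parameter, so gradients roughly double the footprint.
    """
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return 2 * param_bytes  # parameters + gradients

# Hypothetical example model, just to show the calculation
model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))
print(param_and_grad_bytes(model))  # 5560 float32 params -> 2 * 5560 * 4 = 44480 bytes
```

This ignores optimizer state (e.g. Adam keeps two extra buffers per parameter), which would add further multiples of the parameter size.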
But I guess several other things are computed and stored during the forward and backward passes of training, and to be honest I have no idea what is stored and how much memory it takes.
Could you suggest a strategy for estimating this?