Estimate memory taken by network for one epoch


I would like a method to compute or estimate the amount of memory the network will use during one epoch of training. I am not considering the size of the inputs here, only the memory taken by the network and its internal parameters.

I know the network computes the gradients of the parameters during the backward pass, and I think it is straightforward to compute the memory taken by those gradients.

But I guess several other things are computed and stored during the forward and backward passes of training, and to be honest I have no idea what is stored and how much memory it takes.
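For the parameters and gradients, a rough back-of-envelope sketch is possible with standard `torch.nn` introspection. `train_memory_estimate_bytes` below is a hypothetical helper of my own, not a library function, and the optimizer-state factor is an assumption that holds for plain Adam (which keeps two extra buffers per parameter):

```python
# Rough sketch: steady-state training memory for the network itself is about
#   parameters + gradients (+ optimizer state),
# plus the activations saved during the forward pass for use in backward.
import torch.nn as nn

def train_memory_estimate_bytes(model: nn.Module, adam: bool = False) -> int:
    # each parameter stores its value once
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    grad_bytes = param_bytes                     # .grad mirrors each parameter
    opt_bytes = 2 * param_bytes if adam else 0   # Adam: exp_avg + exp_avg_sq
    return param_bytes + grad_bytes + opt_bytes

net = nn.Linear(10, 5)   # 10*5 weights + 5 biases = 55 float32 parameters
print(train_memory_estimate_bytes(net))  # 55 * 4 bytes * 2 = 440
```

This deliberately leaves out the saved activations, whose size depends on the input shape and batch size rather than on the parameters alone.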

Could you suggest a strategy for this?



from torchsummary import summary

summary(net, (3, 32, 32))

Here net is your particular network (an nn.Module instance); replace (3, 32, 32) with your input size. This reports the estimated parameter size as well as the total size, among other metrics. This may help you.
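If you also want a rough number for the activations produced in the forward pass (roughly what accounts for the forward/backward-pass portion of the total), one way to approximate it yourself is to register forward hooks on the leaf modules and sum the sizes of their output tensors. `activation_bytes` below is a hypothetical sketch of mine, not a library API:

```python
import torch
import torch.nn as nn

def activation_bytes(model: nn.Module, input_shape) -> int:
    """Rough estimate: sum of output-tensor sizes of all leaf modules."""
    total = 0
    hooks = []

    def hook(module, inputs, output):
        nonlocal total
        if isinstance(output, torch.Tensor):
            total += output.numel() * output.element_size()

    for m in model.modules():
        if len(list(m.children())) == 0:  # hook leaf modules only
            hooks.append(m.register_forward_hook(hook))
    with torch.no_grad():
        model(torch.zeros(1, *input_shape))  # dummy batch of size 1
    for h in hooks:
        h.remove()
    return total

net = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
print(activation_bytes(net, (8,)))  # (4 + 4) float32 values = 32 bytes
```

Multiply the result by your batch size for a per-batch estimate; intermediate tensors created inside a module's forward (not returned as outputs) are not counted, so treat this as a lower bound.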
