How to estimate the necessary (GPU) memory for a model

I have a model that throws a CUDA out-of-memory error during training (raised from .forward()):

File "/home/.../.../run/", line 131, in main
  output = model(data)
File "/home/.../anaconda3/envs/.../lib/python3.7/site-packages/torch/nn/modules/", line 541, in __call__
  result = self.forward(*input, **kwargs)
File "/home/.../.../model/", line 47, in forward
  x = self.layers[f"conv_{i}"](x)
File "/home/.../anaconda3/envs/.../lib/python3.7/site-packages/torch/nn/modules/", line 541, in __call__
  result = self.forward(*input, **kwargs)
File "/home/.../anaconda3/envs/.../lib/python3.7/site-packages/torch/nn/modules/", line 345, in forward
  return self.conv2d_forward(input, self.weight)
File "/home/.../anaconda3/envs/.../lib/python3.7/site-packages/torch/nn/modules/", line 342, in conv2d_forward
  self.padding, self.dilation, self.groups)
RuntimeError: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 7.93 GiB total capacity; 7.35 GiB already allocated; 12.75 MiB free; 87.83 MiB cached)

When I checked the model size based on its parameters, it definitely fits into memory, and the batch size is small enough that neither of these can be the source of the exception.
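Counting parameters covers only part of the batch-independent cost. During training every trainable parameter also gets a gradient of the same size, and stateful optimizers keep extra tensors per parameter (torch.optim.Adam keeps two: exp_avg and exp_avg_sq). A minimal sketch of that bookkeeping follows; the helper names are hypothetical, and the default of two optimizer states per parameter assumes Adam.

```python
import torch
import torch.nn as nn


def tensor_bytes(tensors):
    """Total bytes occupied by an iterable of tensors."""
    return sum(t.numel() * t.element_size() for t in tensors)


def static_training_bytes(model: nn.Module, optimizer_states_per_param=2):
    """Rough estimate of the training memory that does NOT depend on batch size:
    weights and buffers + one gradient per trainable parameter + optimizer state.
    optimizer_states_per_param=2 assumes Adam (exp_avg, exp_avg_sq);
    use 0 for plain SGD or 1 for SGD with momentum."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    weights = tensor_bytes(model.parameters()) + tensor_bytes(model.buffers())
    grads = tensor_bytes(trainable)
    optimizer_state = optimizer_states_per_param * tensor_bytes(trainable)
    return weights + grads + optimizer_state


if __name__ == "__main__":
    model = nn.Linear(10, 5)  # 55 float32 parameters
    print(tensor_bytes(model.parameters()))  # weights only
    print(static_training_bytes(model))      # weights + grads + Adam state
```

For nn.Linear(10, 5) (55 float32 parameters, 220 bytes) this gives 880 bytes with Adam, four times the parameter size. Activations, which grow with batch size and input resolution, come on top of this.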

(A friend told me that it might come from the activation tensors that are stored for the backward pass.)
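The activation hypothesis can be checked directly. PyTorch 1.10 and later provide torch.autograd.graph.saved_tensors_hooks, which calls a "pack" hook for every tensor autograd saves for the backward pass. The sketch below only adds up bytes inside that hook, so no backward() call is needed; the function name is hypothetical, and deduplicating by data_ptr() is a simplification (views at different offsets into the same storage are still counted twice).

```python
import torch
import torch.nn as nn


def saved_for_backward_bytes(model: nn.Module, batch: torch.Tensor):
    """Run one forward pass and sum the bytes of the tensors autograd saves
    for backward. Returns (total_bytes, output). Requires PyTorch >= 1.10."""
    param_ptrs = {p.data_ptr() for p in model.parameters()}
    seen = set()
    total = 0

    def pack(t):
        nonlocal total
        ptr = t.data_ptr()
        # Skip parameters (already counted with the weights) and storage that
        # another op has already saved, so shared activations count once.
        if ptr not in param_ptrs and ptr not in seen:
            seen.add(ptr)
            total += t.numel() * t.element_size()
        return t  # keep the tensor itself; we only observe it

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = model(batch)  # keep output alive so the graph is not freed
    return total, output


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
    nbytes, _ = saved_for_backward_bytes(model, torch.randn(4, 10))
    print(f"saved for backward: {nbytes} bytes")
```

Running the same function under torch.no_grad() returns 0, because no graph is built and nothing is saved; that gap is roughly the extra activation memory training holds on to compared with inference.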

In this post, I would like to ask the following questions.

  1. What does PyTorch allocate memory for other than the model and the data (especially during training)? I would like to know the exact cause of the exception.
  2. Is there any way of estimating how much memory the model requires “prior to training” and “programmatically”?
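Because allocator caching and per-operator workspaces make a purely analytic count approximate, a practical programmatic answer before real training is to run one training step on a dummy batch and read the allocator's peak. A sketch, assuming a CUDA device, a fixed input shape, Adam as the optimizer, and a placeholder mean() loss; peak_training_bytes is a hypothetical name.

```python
import torch
import torch.nn as nn


def peak_training_bytes(model: nn.Module, input_shape, device="cuda"):
    """Run one forward/backward/optimizer step on random data and return the
    peak bytes occupied by tensors on the GPU, or None if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return None
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters())  # assumption: Adam
    data = torch.randn(*input_shape, device=device)

    # Start measuring from the current state (model weights already allocated).
    torch.cuda.reset_peak_memory_stats(device)
    optimizer.zero_grad()
    output = model(data)
    loss = output.float().mean()  # placeholder; use the real loss function
    loss.backward()
    optimizer.step()              # Adam allocates its state on the first step
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device)


if __name__ == "__main__":
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
    print(peak_training_bytes(net, (2, 3, 32, 32)))
```

torch.cuda.max_memory_allocated counts memory occupied by tensors only; the CUDA context and cached-but-unused blocks (the "cached" figure in the error message) are not included, so leave some headroom when comparing against the card's total capacity.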

By “memory” I mean something as simple as the number of floats (tensor elements), since I can always convert one metric into the other.
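Converting between the two metrics only needs the element size of the dtype. A small sketch using Tensor.element_size(); the helper names are hypothetical.

```python
import torch


def bytes_per_element(dtype=torch.float32):
    """Size in bytes of a single element of the given dtype."""
    return torch.empty(0, dtype=dtype).element_size()


def elements_to_bytes(numel, dtype=torch.float32):
    return numel * bytes_per_element(dtype)


def bytes_to_elements(nbytes, dtype=torch.float32):
    return nbytes // bytes_per_element(dtype)


if __name__ == "__main__":
    # The failed allocation in the traceback: 14.00 MiB, assuming float32.
    print(bytes_to_elements(14 * 2**20))
```

For example, the 14.00 MiB allocation that failed in the traceback corresponds to 14 * 2**20 / 4 = 3,670,016 float32 values, assuming that tensor is float32. Halving the element size (float16) halves the bytes for the same element count.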

Thank you for your time

I am not sure about PyTorch's own built-in memory usage. There is a library called pytorch_memlab that can be used to inspect the GPU memory usage of each line of your code.

It looks like I have to call backward() to check the usage though…
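If calling backward() is the obstacle, a per-layer view that needs only a forward pass can be sketched with nn.Module.register_forward_hook. This is not pytorch_memlab; it records the size of each leaf module's output, which approximates (but does not exactly equal) what autograd keeps for the backward pass. activation_report is a hypothetical name.

```python
import torch
import torch.nn as nn


def activation_report(model: nn.Module, batch: torch.Tensor):
    """Return a list of (layer_name, output_bytes) for every leaf module,
    collected from one forward pass under no_grad (backward() is never called)."""
    report = []
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            nbytes = sum(o.numel() * o.element_size()
                         for o in outs if torch.is_tensor(o))
            report.append((name, nbytes))
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(batch)
    finally:
        for h in handles:  # always detach the hooks again
            h.remove()
    return report


if __name__ == "__main__":
    net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
    for name, nbytes in activation_report(net, torch.randn(4, 10)):
        print(f"{name:>4}: {nbytes} bytes")
```

For nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)) and a batch of 4, it reports 320, 320, and 80 bytes for layers "0", "1", and "2". Because the numbers scale linearly with batch size, one run with a small batch is enough to extrapolate to larger ones.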