Why is this simple model becoming so enormous during training?

I reproduced a light and efficient UNet model from a paper, called HalfUNet. My code is here:

When I build it with the default parameters, like this:
model = HalfUNet(in_channels=1, out_channels=2, feature_channels=64, depth=5)

my stats show that this model has 593,795 parameters.
Since each parameter is a float32, the weights should take only a few MB of memory (593,795 × 4 bytes ≈ 2.3 MB), and indeed CUDA reports about 4MB when I load it into memory.

I can run inference on tensors shaped (128, 1, 256, 256) with no issues at all.

But as soon as I try to train it (with the same tensor size), it suddenly inflates to enormous proportions, taking more than 20GB! It crashes both on my Apple M1 (MPS) and on my RTX 3090 (which has 24GB of VRAM!); I can't even train it on a single batch.

I can see this on my Mac by tracking memory usage, and on my RTX 3090 by watching the memory usage suddenly increase until it displays this error message:

RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 24.00 GiB total capacity; 21.72 GiB already allocated; 937.09 MiB free; 21.74 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

What's the issue? What node or connection makes it go wild like that?

AFAIK the only difference between inference mode and training mode in my code is:

with torch.set_grad_enabled(False):
    model.eval()

with torch.set_grad_enabled(True):
    model.train()

Here's the error message I get with my RTX 3090 when training the model as a TorchScript module:

Traceback (most recent call last):
  File "C:\Users\divide\TorchStudio\python\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\divide\TorchStudio\python\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "C:\Users\divide\TorchStudio\torchstudio\modeltrain.py", line 307, in <module>
    outputs = model(*inputs)
  File "C:\Users\divide\TorchStudio\python\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File \"code/__torch__/modelmodule.py\", line 39, in forward
    _8 = (_1).forward(_7, )
    _9 = (_20).forward(_7, )
    _11 = (_2).forward(_9, )
           ~~~~~~~~~~~ <--- HERE
    _12 = (_30).forward(_9, )
    _13 = (_3).forward(_12, )
  File \"code/__torch__/torch/nn/modules/upsampling/___torch_mangle_54.py\", line 8, in forward
  def forward(self: __torch__.torch.nn.modules.upsampling.___torch_mangle_54.Upsample,
    argument_1: Tensor) -> Tensor:
    _0 = torch.upsample_bilinear2d(argument_1, None, False, [4., 4.])
         ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return _0

Traceback of TorchScript, original code (most recent call last):
/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/nn/functional.py(3950): interpolate
/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/nn/modules/upsampling.py(156): forward
/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/nn/modules/module.py(1182): _slow_forward
/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/nn/modules/module.py(1194): _call_impl
<string>(70): forward
/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/nn/modules/module.py(1182): _slow_forward
/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/nn/modules/module.py(1194): _call_impl
/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/jit/_trace.py(976): trace_module
/Users/divide/TorchStudio/python/lib/python3.9/site-packages/torch/jit/_trace.py(759): trace
/Users/divide/TorchStudio/torchstudio/modules.py(63): safe_exec
/Users/divide/TorchStudio/torchstudio/modelbuild.py(108): <module>
/Users/divide/TorchStudio/python/lib/python3.9/runpy.py(87): _run_code
/Users/divide/TorchStudio/python/lib/python3.9/runpy.py(197): _run_module_as_main
RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 24.00 GiB total capacity; 21.72 GiB already allocated; 937.09 MiB free; 21.74 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

and here’s the error message I get with my RTX 3090 when training the model as a torch.package:

Traceback (most recent call last):
  File "C:\Users\divide\TorchStudio\python\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\divide\TorchStudio\python\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "C:\Users\divide\TorchStudio\torchstudio\modeltrain.py", line 307, in <module>
    outputs = model(*inputs)
  File "C:\Users\divide\TorchStudio\python\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "<torch_package_0>.modelmodule.py", line 79, in forward
  File "C:\Users\divide\TorchStudio\python\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\divide\TorchStudio\python\lib\site-packages\torch\nn\modules\container.py", line 204, in forward
    input = module(input)
  File "C:\Users\divide\TorchStudio\python\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\divide\TorchStudio\python\lib\site-packages\torch\nn\modules\conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "C:\Users\divide\TorchStudio\python\lib\site-packages\torch\nn\modules\conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.59 GiB (GPU 0; 24.00 GiB total capacity; 22.38 GiB already allocated; 255.96 MiB free; 22.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

The parameters can use just a small fraction of the overall memory footprint, depending on the model architecture.
E.g. conv layers usually have very few parameters (only the often-small kernel and bias), while the output activation might be huge (it is needed for the gradient computation and is stored if Autograd is enabled). This post describes it in more detail for a ResNet architecture, and you can already see the effect in this small code snippet:

import torch
import torch.nn as nn

# memory used by the input tensor, in bytes
x = torch.randn(1, 3, 224, 224)
print(x.nelement() * x.element_size())
# 602112

# memory used by the conv layer's parameters (weight + bias), in bytes
conv = nn.Conv2d(3, 64, 3, 1, 1)
params = sum([p.nelement() for p in conv.parameters()])
print(params * conv.weight.element_size())
# 7168

# memory used by the output activation, in bytes
out = conv(x)
print(out.nelement() * out.element_size())
# 12845056

print(12845056 / 7168)
# 1792.0

As you can see, the output activation uses ~1,800x more memory than the parameters of the conv layer.

That's right, but as I said, nothing crazy happens during inference; it stays within a very reasonable range. The entire architecture, and how the tensor sizes change from one node to the next, can be seen in the image below; as you'll see, it's quite simple and straightforward.

Something is leaking memory like hell, but only when gradients are enabled.

During inference the activations are not stored for the entire forward pass, but only long enough to compute the next activation; afterwards they are freed, which reduces the overall memory footprint.
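You can see the difference directly by comparing peak memory with and without gradients. A minimal sketch (assuming a CUDA device and using a stand-in conv stack, not your actual HalfUNet):

import torch
import torch.nn as nn

def peak_mem_mb(fn):
    # Peak allocated CUDA memory (in MB) while running fn.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2

# Stand-in model: a small stack of convs, just to make the effect visible.
model = nn.Sequential(*[nn.Conv2d(64, 64, 3, 1, 1) for _ in range(5)]).cuda()
x = torch.randn(16, 64, 256, 256, device="cuda")

with torch.no_grad():
    # activations are freed as soon as they are no longer needed
    print("inference peak: %.0f MB" % peak_mem_mb(lambda: model(x)))

# with gradients enabled every intermediate activation is kept for the backward pass
print("training peak:  %.0f MB" % peak_mem_mb(lambda: model(x)))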
What makes you think you are leaking memory?

After paying closer attention to the architecture, I see that at some point 5 tensors are added together, each of size 128x64x256x256. If the backprop phase needs all of those tensors available at once, that's already 10GB of memory right there. And I could imagine it needs a duplicate of those for some calculation, which brings it to 20GB.
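In numbers (a quick float32 sanity check of that estimate):

elems = 128 * 64 * 256 * 256        # one of the summed activation tensors
per_tensor = elems * 4              # float32 -> bytes
print(per_tensor / 1024**3)         # 2.0 GiB per tensor
print(5 * per_tensor / 1024**3)     # 10.0 GiB if all five are kept for backprop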

Well I guess that would be the reason then…

Yes, this could explain it. Your model architecture mostly uses conv, pooling, and activation layers, which can all create large output shapes (depending on the actual input shape).
Some CNNs (such as ResNet) reduce the spatial size via pooling layers etc. to decrease the memory footprint, but it seems you are also upsampling the activations again (i.e. increasing their spatial size).
To avoid storing these activations and recompute them during the backward pass instead, you could check torch.utils.checkpoint.checkpoint.
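The call pattern looks roughly like this; a minimal sketch with a stand-in conv block, not your actual HalfUNet code:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Stand-in block, just to show how checkpointing is applied.
block = nn.Sequential(nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(),
                      nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU())

x = torch.randn(2, 64, 256, 256, requires_grad=True)

# Instead of `out = block(x)`: the block's intermediate activations are not
# stored during the forward pass and are recomputed during backward.
out = checkpoint(block, x)
out.sum().backward()

The trade-off is an extra forward pass through the checkpointed blocks during the backward pass, so you pay compute to save memory.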


Thanks for the insights and tips.
Also, thanks for the tremendous support you've been providing here over the years; it's a real pleasure to have someone as responsive and as educational as you!


Sure, happy to help, and thanks for the kind words. Let me know if you get stuck using checkpointing (it might not be trivial, especially for a UNet) or if some memory usage is not explainable (in which case I would need to check the model architecture and debug it).