Another CUDA out of memory issue

I’ve been trying to build a 2.5D UNet model to produce segmentations from CT scan DICOM files. In short, I want to train on series of N (512, 512) images against the corresponding N slices of a segmentation volume (a NIfTI file).

I have had all sorts of problems getting this to work on my GPU (Windows 11, RTX 4080 16GB).

Initially, the model would not train with pin_memory=True: I got a “CUDA out of memory” error after the first or second batch. With pin_memory=False, I can train the model through the first epoch, but when the validation step begins I get another “CUDA out of memory” error:

OutOfMemoryError: CUDA out of memory. Tried to allocate 9.19 GiB (GPU 0; 15.99 GiB total capacity; 10.23 GiB already allocated; 674.34 MiB free; 12.93 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
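The figures in this message already suggest that fragmentation is not the cause. A rough upper bound on what any single new allocation could obtain is the memory the driver still reports as free plus the memory PyTorch has reserved but not handed out to tensors. The sketch below only does that subtraction with the numbers copied from the message; the helper name is made up for illustration.

```python
def best_case_available_gib(free_gib, reserved_gib, allocated_gib):
    """Hypothetical helper: upper bound on the size of one new CUDA allocation.

    free_gib: memory the driver still reports as free
    reserved_gib - allocated_gib: memory PyTorch has cached but not handed out
    A single request also needs one contiguous block, so the real limit can only be lower.
    """
    return free_gib + (reserved_gib - allocated_gib)


# Figures copied from the OOM message above (674.34 MiB converted to GiB)
requested = 9.19
available = best_case_available_gib(free_gib=674.34 / 1024, reserved_gib=12.93, allocated_gib=10.23)
print(f"requested {requested:.2f} GiB, at most {available:.2f} GiB obtainable")
```

Since roughly 3.4 GiB can never hold a 9.19 GiB tensor, a max_split_size_mb setting is unlikely to help here; as the message itself says, that option targets the case where reserved memory is much larger than allocated memory.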

Below is the function I am trying to use to do a simple accuracy check (I’m also using Dice loss as my loss function):

def check_accuracy(loader, model, device='cuda'):
    model.eval()
    num_correct = 0
    num_pixels = 0
    
    with torch.no_grad():
        for X, y in loader:
            for i in range(X.size(0)):
                print('Before data allocation', torch.cuda.memory_allocated())
                X_i = X[i].squeeze(0).float().to(device)
                y_i = y[i].squeeze(0).int().to(device)
                print('After data allocation', torch.cuda.memory_allocated())
                preds = torch.sigmoid(model(X_i))
                print('After evaluation', torch.cuda.memory_allocated())
                preds = (preds > 0.5).float()
                num_correct += (preds == y_i).sum()
                num_pixels += torch.numel(preds)
                del preds
        del X
        del y
        gc.collect()
        torch.cuda.empty_cache()
    print(f'Got {num_correct} with acc {num_correct/num_pixels*100:.2f}' )
    model.train()

Here is the output of the first two print statements before CUDA raises the error when the model is called:

Before data allocation 503097344
After data allocation 1119660032

As you can see, only 0.5GB was in use before the data (X_i, y_i) was allocated, and after allocation memory use jumps to 1.1GB. I estimate that the model itself takes only 0.13GB. Why on earth does PyTorch try to allocate 9.19GB on the GPU? I have looked for answers everywhere (here, on Stack Overflow, etc.) but can’t find any solution, or even a similar issue. I have tried deleting obsolete objects, emptying the CUDA cache, running garbage collection, switching the model to eval() and back to train(), and using no_grad for evaluation; I’m running out of ideas. I also tried setting max_split_size_mb to 256, 512 and 1024, with the same result.
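The requested size is probably neither the data nor the model but an intermediate activation. In check_accuracy the loader has a batch size of 1 and each item is a whole series, so X[i].squeeze(0) is the full (N, 3, 512, 512) volume, and the first Conv2d(3, 64) of the UNet must produce an N x 64 x 512 x 512 float32 tensor in one go, with or without no_grad. The arithmetic below works backwards from the two memory_allocated() readings. It assumes X is float32 with 3 channels and y is int32 with 1 channel per slice (the dtypes follow from .float() and .int(); the single target channel is a guess).

```python
GIB = 2**30
H = W = 512

before = 503_097_344    # "Before data allocation"
after = 1_119_660_032   # "After data allocation"

# Assumed per-slice layout: X is float32 with 3 channels, y is int32 with 1 channel
x_bytes = 3 * H * W * 4
y_bytes = 1 * H * W * 4
n_slices, remainder = divmod(after - before, x_bytes + y_bytes)

# Output of the first Conv2d(3, 64, padding=1) for all slices at once, float32
conv_out = n_slices * 64 * H * W * 4

print(f"slices moved to the GPU in one go: {n_slices} (remainder {remainder} bytes)")
print(f"first conv output: {conv_out / GIB:.2f} GiB")
print(f"allocated once that output exists: {(after + conv_out) / GIB:.2f} GiB")
```

Under those assumptions the readings divide evenly into 147 slices, the first conv output comes to 9.19 GiB, and the total allocated after it comes to 10.23 GiB. Those are exactly the “Tried to allocate” and “already allocated” figures in the error, which points to the next layer (the BatchNorm) asking for a second tensor of the same size.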

Questions:
1- How does PyTorch decide that it needs to allocate 9.19GB on the GPU, when the data and the model are much smaller than that?
2- I have often run into Windows compatibility problems with several packages. Could this be another one of those issues?
3- Why does pin_memory=True not seem to work, when the use case suggests that’s what I should be using?
4- More importantly, how do I fix my problem?

A couple of things you can try to maximize your available memory:

  1. Use dtype=torch.bfloat16 (see the PyTorch article “What Every User Should Know About Mixed Precision Training in PyTorch”).
  2. Reduce your batch size and accumulate gradients (see “CUDA Automatic Mixed Precision examples” in the PyTorch 2.0 documentation).
  3. Reduce the model size.

Note that some memory only gets allocated once data goes through the model: activations always, and the weights themselves only for lazy modules such as nn.LazyConv2d (regular layers allocate their weights when the model is instantiated). Additionally, training can easily need 3-4x the size of your model for gradients and the values the optimizer keeps to work properly.
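To put the “3-4x” in numbers: in float32 every parameter costs 4 bytes for the weight, 4 for its gradient and, with Adam, 8 more for its two moment buffers (exp_avg and exp_avg_sq). The sketch below assumes the 31,037,633 parameters that the UNET25D code posted later in this thread has with its default features; the copy counts are the usual ones for these optimizers and ignore small extras such as BatchNorm buffers.

```python
BYTES_FP32 = 4
N_PARAMS = 31_037_633  # UNET25D(in_channels=3, out_channels=1) with default features

# float32 copies kept per parameter
COPIES = {
    "weights only (inference)": 1,
    "SGD without momentum": 2,   # weights + gradients
    "SGD with momentum": 3,      # + momentum buffer
    "Adam": 4,                   # + exp_avg + exp_avg_sq
}


def state_mib(n_params, copies):
    """Memory in MiB for `copies` float32 values per parameter."""
    return n_params * copies * BYTES_FP32 / 2**20


for name, copies in COPIES.items():
    print(f"{name:26s} {state_mib(N_PARAMS, copies):7.1f} MiB")
```

Even the Adam case stays under 0.5 GiB, so weights and optimizer state alone would not account for a multi-GiB request.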

A few more points: make sure you’re calling optimizer.zero_grad() between optimizer steps.

And to be on the safe side, always call loss.detach() and outputs.detach() before calculating any stats.
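One caveat: Tensor.detach() is not an in-place operation. It returns a new tensor that shares storage with the original but is cut off from autograd, so a bare loss.detach() statement (like the one at the end of the train() function posted later in this thread) has no effect; the in-place variant is detach_(). A small illustration with a toy loss:

```python
import torch

w = torch.ones(3, requires_grad=True)
loss = (w * 2.0).sum()

loss.detach()                    # returns a new tensor, which is discarded here
print(loss.requires_grad)        # `loss` itself is unchanged and still part of the graph

loss_for_stats = loss.detach()   # keep the returned tensor if you need a detached copy
print(loss_for_stats.requires_grad)

print(loss.item())               # .item() already returns a plain Python float
```

For logging a scalar loss, loss.item() on its own is usually enough.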

Thanks for your reply. I’ll address each of your points:

1- I was already using torch.cuda.amp.GradScaler() and torch.cuda.amp.autocast(). In fact, my code was almost a carbon copy of the code snippet in the link you provided. I even tried setting the dtype explicitly (I think autocast() uses torch.float16 by default), and it gave the same result: I can train for one epoch, then CUDA runs out of memory in eval() mode.

2- Again, the issue is not so much with training. I have been able to train one full epoch with batch sizes of 1, 2, 4, 8 and 16, the latter most likely being the limit for my GPU. When I raised this issue, I was using batches of 4. You will also notice that the function I pasted above runs with a batch size of 1. Still, CUDA runs out of memory by requesting an allocation that is even bigger than my whole validation sample (I have seen various requested sizes, from as little as 1GB up to more than 50GB). PyTorch is doing something that does not make sense to me, and I can’t put my finger on it.

3- The model is a simple UNet with 3 downsampling blocks (DoubleConv blocks with BatchNorm and ReLU), a bottleneck and 3 upsampling blocks. With batches of 1, the model takes approximately 125MB; with bigger batches, it goes up to ~360MB. When the model is used for validation or inference, it should normally use only the size of the model, at least according to an article I read while troubleshooting.
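torch.no_grad() stops autograd from keeping activations for a backward pass, but the forward pass still has to materialize them, and a UNet holds every encoder output in skip_connections until the decoder concatenates it. The estimate below assumes float32 inputs (check_accuracy does not use autocast) and the features=[64, 128, 256, 512] default from the UNET25D code posted later in the thread, which gives four encoder levels. It counts only the skip tensors, so it is a lower bound on peak activation memory.

```python
H = W = 512
FEATURES = [64, 128, 256, 512]  # UNET25D default; spatial size halves after each level
BYTES_FP32 = 4

# Every encoder output stays alive in skip_connections until the decoder uses it,
# so all of them coexist at the bottleneck, even under torch.no_grad().
skip_bytes_per_slice = sum(
    f * (H >> level) * (W >> level) * BYTES_FP32 for level, f in enumerate(FEATURES)
)
skip_mib_per_slice = skip_bytes_per_slice / 2**20

print(f"skip connections per slice: {skip_mib_per_slice:.0f} MiB")
for n_slices in (1, 4, 16, 150):  # 150 is a hypothetical series length
    print(f"{n_slices:4d} slices: {n_slices * skip_mib_per_slice / 1024:6.2f} GiB")
```

At 120 MiB per slice, the skip tensors for a single slice are already about as large as the weights, and a hypothetical 150-slice series would need over 17 GiB for them alone, whatever the model size.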

4- About the optimizer: following cues from the same article, I tried SGD (no moments calculated, so no additional memory requirements) and Adam (2 moments calculated, so at least 2X my model size in additional memory), and both lead to the same result: I can train, but I can’t validate or run inference.

5- optimizer.zero_grad(): it was already in my code.

6- loss/outputs.detach(): I added these, but it didn’t change the result.

So, despite your good advice, I’m still stuck with the issues I have outlined above.

Would you mind sharing your code and the entire error showing which line of code caused the issue?

Here is the full error message (it’s long). In this example, the code trips on BatchNorm when it tries to execute preds = torch.sigmoid(model(X_i)) in the check_accuracy function I posted previously; I have also seen it trip on Conv2d:

OutOfMemoryError                          Traceback (most recent call last)
d:\rsna-2023-abdominal-trauma-detection\test.py in line 5
      1 import train
      3 if __name__ == '__main__':
----> 5     train.main()

File d:\rsna-2023-abdominal-trauma-detection\train.py:116, in main()
    113 save_checkpoint(checkpoint)
    115 # check accuracy/DICE
--> 116 check_accuracy(val_loader, model, device=config.DEVICE)
    118 save_predictions_as_imgs(val_loader, model, folder='saved_images/', device=config.DEVICE)

File d:\rsna-2023-abdominal-trauma-detection\utils.py:89, in check_accuracy(loader, model, device)
     87 y_i = y.squeeze(0).int().to(device)
     88 print('After data allocation', torch.cuda.memory_allocated())
---> 89 preds = torch.sigmoid(model(X_i))
     90 print('After evaluation', torch.cuda.memory_allocated())
     91 preds = (preds > 0.5).float()

File c:\Users\probi\anaconda3\envs\kaggle\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File d:\rsna-2023-abdominal-trauma-detection\models.py:43, in UNET25D.forward(self, x)
     41 skip_connections = []
     42 for down_block in self.downsampling:
---> 43     x = down_block(x)
     44     skip_connections.append(x)
     45     x = self.pooling(x)

File c:\Users\probi\anaconda3\envs\kaggle\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File d:\rsna-2023-abdominal-trauma-detection\models.py:16, in DoubleConv.forward(self, x)
     15 def forward(self, x):
---> 16     return self.conv_block(x)

File c:\Users\probi\anaconda3\envs\kaggle\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\probi\anaconda3\envs\kaggle\lib\site-packages\torch\nn\modules\container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File c:\Users\probi\anaconda3\envs\kaggle\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File c:\Users\probi\anaconda3\envs\kaggle\lib\site-packages\torch\nn\modules\batchnorm.py:171, in _BatchNorm.forward(self, input)
    164     bn_training = (self.running_mean is None) and (self.running_var is None)
    166 r"""
    167 Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
    168 passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
    169 used for normalization (i.e. in eval mode when buffers are not None).
    170 """
--> 171 return F.batch_norm(
    172     input,
    173     # If buffers are not to be tracked, ensure that they won't be updated
    174     self.running_mean
    175     if not self.training or self.track_running_stats
    176     else None,
    177     self.running_var if not self.training or self.track_running_stats else None,
    178     self.weight,
    179     self.bias,
    180     bn_training,
    181     exponential_average_factor,
    182     self.eps,
    183 )

File c:\Users\probi\anaconda3\envs\kaggle\lib\site-packages\torch\nn\functional.py:2450, in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
   2447 if training:
   2448     _verify_batch_size(input.size())
-> 2450 return torch.batch_norm(
   2451     input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
   2452 )

OutOfMemoryError: CUDA out of memory. Tried to allocate 9.00 GiB (GPU 0; 15.99 GiB total capacity; 10.03 GiB already allocated; 3.23 GiB free; 10.38 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

The implementation of UNet:

import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv_block(x)

class UNET25D(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNET25D, self).__init__()
        self.upsampling = nn.ModuleList()
        self.downsampling = nn.ModuleList()
        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling
        for feature in features:
            self.downsampling.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Upsampling
        for feature in reversed(features):
            self.upsampling.append(
                nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)
            )
            self.upsampling.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down_block in self.downsampling:
            x = down_block(x)
            skip_connections.append(x)
            x = self.pooling(x)
        
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for index in range(0, len(self.upsampling), 2):
            x = self.upsampling[index](x)
            skip_connection = skip_connections[index//2]
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.upsampling[index + 1](concat_skip)
        
        return self.final_conv(x)
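As a sanity check on the ~125MB model size mentioned earlier, the parameter count of this architecture can be worked out from the layer shapes alone: the 3x3 convolutions have no bias, each BatchNorm2d adds a weight and a bias per channel, and the ConvTranspose2d layers and the final 1x1 conv keep their default biases. The pure-Python sketch below mirrors the constructor with its default arguments; it should agree with sum(p.numel() for p in model.parameters()), which does not count the BatchNorm running statistics (those are buffers).

```python
def conv_params(c_in, c_out, k, bias):
    """Weights of a k x k Conv2d / ConvTranspose2d, plus the bias if it has one."""
    return c_in * c_out * k * k + (c_out if bias else 0)


def double_conv_params(c_in, c_out):
    """DoubleConv: two bias-free 3x3 convs, each followed by BatchNorm2d (weight + bias)."""
    return (conv_params(c_in, c_out, 3, bias=False) + 2 * c_out
            + conv_params(c_out, c_out, 3, bias=False) + 2 * c_out)


def unet25d_params(in_channels=3, out_channels=1, features=(64, 128, 256, 512)):
    total = 0
    c = in_channels
    for f in features:                                  # downsampling path
        total += double_conv_params(c, f)
        c = f
    total += double_conv_params(features[-1], features[-1] * 2)  # bottleneck
    for f in reversed(features):                        # upsampling path
        total += conv_params(f * 2, f, 2, bias=True)    # ConvTranspose2d(2f, f, kernel_size=2)
        total += double_conv_params(f * 2, f)
    total += conv_params(features[0], out_channels, 1, bias=True)  # final 1x1 conv
    return total


n_params = unet25d_params()
print(f"{n_params:,} parameters, {n_params * 4 / 2**20:.1f} MiB in float32")
```

This comes to 31,037,633 parameters, about 118 MiB in float32, in line with the ~125MB reported, so the weights are not where the gigabytes go.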

Finally, my train function:

def train(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    
    for batch_idx, (data, targets) in enumerate(loop):
        print('Batch ', batch_idx)
        data = data.squeeze(0)
        targets = targets.squeeze(0)
        BATCH_SIZE = 4
        mini_batches = math.ceil(data.size(0) / BATCH_SIZE)
        for i in range(mini_batches):
            mb_from = i * BATCH_SIZE
            mb_to = min((i + 1) * BATCH_SIZE, data.size(0))
            X = data[mb_from:mb_to].float().to(device=config.DEVICE, non_blocking=True)
            y = targets[mb_from:mb_to].int().to(device=config.DEVICE, non_blocking=True)
            
            # Forward prop
            with torch.cuda.amp.autocast(dtype=torch.float16):
                predictions = model(X)
                loss = loss_fn(predictions, y)
                del X
                del y
                gc.collect()
                torch.cuda.empty_cache()
            
            # Backward prop
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Update tqdm
        loop.set_postfix(loss=loss.item())
        loss.detach()

I’m not seeing where you are calling the check_accuracy function in your code. Anyway, using just your model and check_accuracy function with random inputs, I’m not able to reproduce the issue.

Before data allocation 125523456
After data allocation 125523456
After evaluation 125539840

Can you make a minimal but complete code that reproduces the problem?

Here is the main piece of code (with some edits). train_sample_list and val_sample_list are lists of tuples used together with img_path and seg_path to populate and load the datasets. Also, as mentioned previously, pin_memory does not work for me: I get CUDA OOM errors during training when I set it to True.

def main():
    train_transforms = Compose(
        [Normalize(
             mean=[0.0, 0.0, 0.0],
             std=[1.0, 1.0, 1.0],
             inplace=True
         )]
    )

    val_transforms = Compose(
        [Normalize(
             mean=[0.0],
             std=[1.0],
             inplace=True
         )]
    )

    model = UNET25D(in_channels=3, out_channels=1).to(config.DEVICE)
    loss_fn = GeneralizedDiceLoss(sigmoid=True, w_type='simple')
    optimizer = optim.Adam(model.parameters(), lr=config.INIT_LR)

    train_loader, val_loader = get_loaders(
        img_path,
        seg_path,
        train_sample_list,
        val_sample_list,
        1, # Batch size: one series of DICOM files/one segmentation volume
        train_transforms,
        val_transforms,
        window_center=50,
        window_width=400,
        extra_only=True,
        num_workers=4,
        pin_memory=False
    )

    if config.LOAD_MODEL:
        load_checkpoint(torch.load('checkpoint.pth.tar'), model)

    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(config.NUM_EPOCHS):
        train(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {'state_dict': model.state_dict(),
                      'optimizer': optimizer.state_dict()}
        save_checkpoint(checkpoint)

        # check accuracy/DICE
        check_accuracy(val_loader, model, device=config.DEVICE)

        save_predictions_as_imgs(val_loader, model, folder='saved_images/', device=config.DEVICE)

When I say “complete code”, I mean something I can copy and paste in its entirety into a scratch file for debugging, which is self-contained and reproduces the error. So any classes or definitions used are either contained in the code or imported from the torch library.

OK, here goes:

import torch
from tqdm import tqdm
import torch.optim as optim
from torchvision.transforms import Compose, Normalize
from models import UNET25D
from monai.losses.dice import GeneralizedDiceLoss
from utils import load_checkpoint, save_checkpoint, get_loaders, check_accuracy, save_predictions_as_imgs
import config
import pandas as pd
import numpy as np
import math
import gc
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024"

def train(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    
    for batch_idx, (data, targets) in enumerate(loop):
        print('Batch ', batch_idx)
        data = data.squeeze(0)
        targets = targets.squeeze(0)
        BATCH_SIZE = 4
        mini_batches = math.ceil(data.size(0) / BATCH_SIZE)
        for i in range(mini_batches):
            mb_from = i * BATCH_SIZE
            mb_to = min((i + 1) * BATCH_SIZE, data.size(0))
            X = data[mb_from:mb_to].float().to(device=config.DEVICE, non_blocking=True)
            y = targets[mb_from:mb_to].int().to(device=config.DEVICE, non_blocking=True)
            
            # Forward prop
            with torch.cuda.amp.autocast(dtype=torch.float16):
                predictions = model(X)
                if torch.sum(y) > 0:
                    for i in range(predictions.size(0)):
                        print('Predict', torch.sum(predictions[i]))
                        print('Target', torch.sum(y[i]))
                loss = loss_fn(predictions, y)
                del X
                del y
                gc.collect()
                torch.cuda.empty_cache()
            
            # Backward prop
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Update tqdm
        loop.set_postfix(loss=loss.item())
        loss.detach()

def main():
    train_transforms = Compose(
        [Normalize(
             mean=[0.0, 0.0, 0.0],
             std=[1.0, 1.0, 1.0],
             inplace=True
         )]
    )

    val_transforms = Compose(
        [Normalize(
             mean=[0.0],
             std=[1.0],
             inplace=True
         )]
    )

    all_series = pd.read_csv('D:\\sample_data.csv')
    healthy = all_series.loc[all_series['Sample'] == 'Healthy', ['patient_id', 'series_id']]
    injured = all_series.loc[all_series['Extra Mapped'] == 1, ['patient_id', 'series_id']]
    healthy_samples = np.array(list(zip(healthy['patient_id'], healthy['series_id'])))
    injured_samples = np.array(list(zip(injured['patient_id'], injured['series_id'])))
    healthy_select = healthy_samples[np.random.choice(len(healthy_samples), 36, replace=False)]
    injured_select = injured_samples[np.random.choice(len(injured_samples), 4, replace=False)]
    train_sample_list = np.concatenate((healthy_select[:28], injured_select[:2]))
    val_sample_list = np.concatenate((healthy_select[28:], injured_select[2:]))

    model = UNET25D(in_channels=3, out_channels=1).to(config.DEVICE)
    loss_fn = GeneralizedDiceLoss(sigmoid=True, w_type='simple')
    optimizer = optim.Adam(model.parameters(), lr=config.INIT_LR)

    train_loader, val_loader = get_loaders(
        img_path,
        seg_path,
        train_sample_list,
        val_sample_list,
        1, # Batch size: one series of DICOM files/one segmentation volume
        train_transforms,
        val_transforms,
        window_center=50,
        window_width=400,
        extra_only=True,
        num_workers=4
    )

    if config.LOAD_MODEL:
        load_checkpoint(torch.load('checkpoint.pth.tar'), model)

    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(config.NUM_EPOCHS):
        train(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {'state_dict': model.state_dict(),
                      'optimizer': optimizer.state_dict()}
        save_checkpoint(checkpoint)

        check_accuracy(val_loader, model, device=config.DEVICE)

utils.py (excerpt):

import torch
import torchvision
from dataset import SegmentationDataset
from torch.utils.data import DataLoader
import gc

def get_loaders(
        img_path,
        seg_path,
        train_sample_list,
        val_sample_list,
        batch_size,
        train_transform,
        val_transform,
        window_center,
        window_width,
        extra_only,
        num_workers=4,
        pin_memory=False
):
    train_dataset = SegmentationDataset(
        img_path = img_path,
        seg_path = seg_path,
        sample_list = train_sample_list, 
        center = window_center, 
        width = window_width, 
        extra_only = extra_only,
        img_transforms=train_transform,
        seg_transforms=val_transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True
    )

    val_dataset = SegmentationDataset(
        img_path = img_path,
        seg_path = seg_path,
        sample_list = val_sample_list, 
        center = window_center, 
        width = window_width, 
        extra_only = extra_only,
        img_transforms=train_transform,
        seg_transforms=val_transform
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False
    )

    return train_loader, val_loader

def check_accuracy(loader, model, device='cuda'):
    model.eval()
    num_correct = 0
    num_pixels = 0
    
    with torch.no_grad():
        for X, y in loader:
            for i in range(X.size(0)):
                print('Before data allocation', torch.cuda.memory_allocated())
                X_i = X.squeeze(0).float().to(device)
                y_i = y.squeeze(0).int().to(device)
                print('After data allocation', torch.cuda.memory_allocated())
                preds = torch.sigmoid(model(X_i))
                print('After evaluation', torch.cuda.memory_allocated())
                preds = (preds > 0.5).float()
                num_correct += (preds == y_i).sum()
                num_pixels += torch.numel(preds)
                del preds
        del X
        del y
        gc.collect()
        torch.cuda.empty_cache()
    print(f'Got {num_correct} with acc {num_correct/num_pixels*100:.2f}' )
    model.train()

Let me know if you need anything else.

I’m not able to reproduce the issue with the check_accuracy function and don’t see anything else out of order with a quick check of your code. Try to run this on your system and see if it still produces the memory issue/error:

import torch
import torch.nn as nn
import gc

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv_block(x)


class UNET25D(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(UNET25D, self).__init__()
        self.upsampling = nn.ModuleList()
        self.downsampling = nn.ModuleList()
        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)

        # Downsampling
        for feature in features:
            self.downsampling.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Upsampling
        for feature in reversed(features):
            self.upsampling.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.upsampling.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down_block in self.downsampling:
            x = down_block(x)
            skip_connections.append(x)
            x = self.pooling(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for index in range(0, len(self.upsampling), 2):
            x = self.upsampling[index](x)
            skip_connection = skip_connections[index // 2]
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.upsampling[index + 1](concat_skip)

        return self.final_conv(x)


def check_accuracy(model, device):
    model.eval()
    num_correct = 0
    num_pixels = 0

    with torch.no_grad():
        for i in range(100):
            print('Before data allocation', torch.cuda.memory_allocated())
            # Dummy input and labels, created directly on the target device
            X_i = torch.rand((1, 3, 64, 64), device=device)
            y_i = torch.rand((1, 1, 64, 64), device=device).int()
            print('After data allocation', torch.cuda.memory_allocated())
            preds = torch.sigmoid(model(X_i))
            print('After evaluation', torch.cuda.memory_allocated())
            preds = (preds > 0.5).float()
            num_correct += (preds == y_i).sum()
            num_pixels += torch.numel(preds)
            del preds
        del X_i
        del y_i
        gc.collect()
        torch.cuda.empty_cache()
    print(f'Got {num_correct} with acc {num_correct / num_pixels * 100:.2f}')
    model.train()

device = torch.device("cuda:0")
model = UNET25D()
model.to(device)

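# Call check_accuracy over and over; if memory_allocated() keeps growing
# between calls, something is being kept alive on the GPU.
while True:
    with torch.no_grad():
        check_accuracy(model, device)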

If that works fine without any memory issues, add a print statement for X_i.size() in your check_accuracy function, just before X_i goes into the model.
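To see why that printed size matters, here is a back-of-the-envelope sketch (not part of the original replies). It assumes float32 activations and a first DoubleConv with 64 features that keeps the 512 x 512 resolution, as in the UNET25D above, and `first_layer_activation_bytes` is a hypothetical helper. It counts only a single activation map, so real peak usage will be higher (the skip connections are kept for the decoder, and the full-resolution concatenation in the decoder is twice as wide).

```python
# Rough lower bound on GPU memory for a UNet forward pass, from X_i.size().
# Assumption: the first DoubleConv maps `channels` -> 64 features and keeps
# H x W (padding=1), and tensors are float32 (4 bytes per element).

def first_layer_activation_bytes(input_shape, first_features=64, bytes_per_element=4):
    """Bytes needed for ONE output map of the first conv layer."""
    num_slices, _channels, height, width = input_shape
    return num_slices * first_features * height * width * bytes_per_element


MiB = 1024 ** 2
GiB = 1024 ** 3

# A single 2.5D input slice (3, 512, 512) -> one (64, 512, 512) map.
print(first_layer_activation_bytes((1, 3, 512, 512)) / MiB)     # 64.0
# A whole ~1000-slice series pushed through the model in one go.
print(first_layer_activation_bytes((1000, 3, 512, 512)) / GiB)  # 62.5
```

So each extra slice in X_i's first dimension costs at least 64 MiB per 64-channel map, which is how a whole series quickly exceeds a 16 GB card.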

On a side note, I found that it runs faster if you remove the del, gc.collect() and torch.cuda.empty_cache() statements: emptying the cache means PyTorch's caching allocator has to keep requesting GPU memory again on the next iterations, which adds a lot of overhead.

Your code worked. When I checked the size of X_i in check_accuracy(), I figured out where I went wrong. I should have used the same mini-batching as in the training part. Also, the order of operations when creating X_i was wrong (I had X[i].squeeze(0); it should have been X.squeeze(0), then X[i]). Otherwise, the first dimension of X_i was the length of a whole series. That explains the multi-GB allocations in the error message (a big series with ~1000 slices could easily push the allocation above 50 GB…). I guess that’s also why pin_memory=True does not work (you can’t pin more memory than you have). I’m an idiot sometimes… Sorry to have wasted some of your time.
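A minimal sketch of the fix described above, i.e. a validation loop that feeds each series to the model a few slices at a time. It assumes, as in this thread, batch_size=1, so each loader item is one series with X of shape (1, num_slices, 3, H, W), and that y has a matching (1, num_slices, 1, H, W) layout; `slices_per_batch` is a hypothetical parameter that would normally match the mini-batch size used for training. The whole series stays on the CPU and only one chunk at a time is copied to the GPU, so no del, gc.collect() or empty_cache() calls are needed.

```python
import torch


def check_accuracy(loader, model, device="cuda", slices_per_batch=8):
    """Pixel accuracy over whole series, evaluated a few slices at a time.

    Assumes each loader item is one series (batch_size=1):
    X: (1, num_slices, C, H, W), y: (1, num_slices, 1, H, W).
    """
    model.eval()
    num_correct = 0
    num_pixels = 0

    with torch.no_grad():
        for X, y in loader:
            X = X.squeeze(0)  # -> (num_slices, C, H, W), still on the CPU
            y = y.squeeze(0)  # -> (num_slices, 1, H, W)
            for start in range(0, X.size(0), slices_per_batch):
                # Copy only this chunk of slices to the device.
                X_i = X[start:start + slices_per_batch].float().to(device)
                y_i = y[start:start + slices_per_batch].float().to(device)
                preds = (torch.sigmoid(model(X_i)) > 0.5).float()
                num_correct += (preds == y_i).sum().item()
                num_pixels += preds.numel()

    model.train()
    accuracy = num_correct / num_pixels * 100
    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    return accuracy
```

The peak GPU memory now depends on `slices_per_batch` rather than on the length of the longest series in the validation set.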