18k params lead to 20GB RAM usage

I have a problem that just makes no sense to me. I have a block with 18k parameters that I want to run on MRI images, so the input is 3D, and I get an out-of-memory error.
This is the code for the block:

import torch
import torch.nn as nn

Padding_Mode = 'reflect'

class DilatedConv3DBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DilatedConv3DBlock, self).__init__()
        assert out_channels % 4 == 0  # out_channels must be divisible by 4 so each dilation branch gets the same number of channels
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels,
                      kernel_size=3, padding='same', padding_mode=Padding_Mode, bias=True),
            nn.InstanceNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )
        kernel_size = 5
        self.conv1 = nn.Conv3d(out_channels, out_channels // 4, kernel_size=kernel_size, dilation=1,
                               padding='same',
                               padding_mode=Padding_Mode, bias=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels // 4, kernel_size=kernel_size, dilation=2,
                               padding='same',
                               padding_mode=Padding_Mode, bias=True)
        self.conv3 = nn.Conv3d(out_channels, out_channels // 4, kernel_size=kernel_size, dilation=4,
                               padding='same',
                               padding_mode=Padding_Mode, bias=True)
        self.conv4 = nn.Conv3d(out_channels, out_channels // 4, kernel_size=kernel_size, dilation=8,
                               padding='same',
                               padding_mode=Padding_Mode, bias=True)
        self.IN = nn.InstanceNorm3d(out_channels)
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv(x)
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        x4 = self.conv4(x)
        x = torch.cat([x1, x2, x3, x4], dim=1)
        x = self.IN(x)
        x = self.relu(x)
        return x
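
For reference, the parameter count works out exactly (a quick sanity check for the DilatedConv3DBlock(1, 12) instantiation used below; InstanceNorm3d contributes nothing with the default affine=False):

# first conv:        1 * 12 * 3**3 + 12 = 336
# each dilated conv: 12 * 3 * 5**3 + 3  = 4503
print(336 + 4 * 4503)  # 18348, matching the printout further down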

This is the code I run:

import os
import torch
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

def cuda_check():
    # Check if CUDA is available
    if torch.cuda.is_available():
        # Get the current memory usage in bytes
        current_memory_usage = torch.cuda.memory_allocated()
        # Get the maximum memory usage seen so far in bytes
        max_memory_usage = torch.cuda.max_memory_allocated()

        # Get the total memory available on the device
        device = torch.device("cuda")
        total_memory = torch.cuda.get_device_properties(device).total_memory

        # Calculate the percentage of memory usage
        memory_usage_percentage = (current_memory_usage / total_memory) * 100

        print(f"Current memory usage: {current_memory_usage / (1024 ** 3)} GB")
        print(f"Maximum memory usage: {max_memory_usage / (1024 ** 3)} GB")
        print(f"Total memory available: {total_memory / (1024 ** 3)} GB")
        print(f"Memory usage percentage: {memory_usage_percentage:.2f}%")
        print()
    else:
        print("CUDA is not available. Please make sure you have CUDA-enabled GPU.")

block = DilatedConv3DBlock(1, 12).to("cuda")
print(sum(p.numel() for p in block.parameters() if p.requires_grad))
cuda_check()
loader = data_loader(['vs-84'], 1, 'train')
x = next(iter(loader))['stripped'].cuda()
cuda_check()
print(x.shape, x.dtype, x.device)    # torch.Size([1, 1, 144, 160, 144]) torch.float32 cuda:0
output = block(x)

And I get this insane output (I removed the path before 'venv' here):

18348
Current memory usage: 7.2479248046875e-05 GB
Maximum memory usage: 7.2479248046875e-05 GB
Total memory available: 19.98974609375 GB
Memory usage percentage: 0.00%

Current memory usage: 0.054759979248046875 GB
Maximum memory usage: 0.054759979248046875 GB
Total memory available: 19.98974609375 GB
Memory usage percentage: 0.27%
. . . . . .
UserWarning: expandable_segments not supported on this platform (Triggered internally at …\c10/cuda/CUDAAllocatorConfig.h:30.)
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
x2 = self.conv2(x)
File "venv\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "venv\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "venv\lib\site-packages\torch\nn\modules\conv.py", line 610, in forward
return self._conv_forward(input, self.weight, self.bias)
File "venv\lib\site-packages\torch\nn\modules\conv.py", line 594, in _conv_forward
return F.conv3d(
File "venv\lib\site-packages\monai\data\meta_tensor.py", line 282, in __torch_function__
ret = super().__torch_function__(func, types, args, kwargs)
File "venv\lib\site-packages\torch\_tensor.py", line 1418, in __torch_function__
ret = func(*args, **kwargs)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 18.54 GiB. GPU 0 has a total capacity of 19.99 GiB of which 18.06 GiB is free. Of the allocated memory 748.85 MiB is allocated by PyTorch, and 31.15 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (CUDA semantics — PyTorch 2.2 documentation)

Any idea how that is happening? It makes no sense to me at all. Even when I add:

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

as suggested, it stays exactly the same.
Thanks in advance to everyone; sometimes it feels like the bugs in PyTorch are black magic I cannot understand.

How large is your input size? You might be ignoring the size of the intermediate activations, which can use the majority of the device memory, especially in CNNs, so the memory usage might be expected.
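
For a rough sense of scale with the shapes from the post (back-of-the-envelope arithmetic, assuming the dilated 3D convs hit an im2col/vol2col-style fallback rather than a fused cuDNN kernel, which is an assumption, not something I verified on this setup):

voxels = 144 * 160 * 144              # output spatial size (padding='same')
act = 12 * voxels * 4 / 2**30         # 12-channel float32 feature map: ~0.15 GiB
col = 12 * 5**3 * voxels * 4 / 2**30  # hypothetical unfolded column buffer: ~18.54 GiB
print(f"{act:.2f} GiB activation, {col:.2f} GiB column buffer")

The second number matches the "Tried to allocate 18.54 GiB" in the error above, so the failing allocation is consistent with one conv materializing a buffer of in_channels * kernel_size**3 * output_voxels floats, not with the parameter count.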

Thank you for commenting. I totally forgot to add the size, sorry about that.
I've edited it to make it easier to understand, but the shape is torch.Size([1, 1, 144, 160, 144]) and the dtype torch.float32 (I get the same result when using torch.randn instead of my data loader).
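
In other words, this minimal, self-contained version fails the same way (with DilatedConv3DBlock as defined above):

import torch
x = torch.randn(1, 1, 144, 160, 144, device="cuda")
block = DilatedConv3DBlock(1, 12).to("cuda")
output = block(x)  # OOMs in self.conv2 just like with the real data
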
I got the warning:

UserWarning: expandable_segments not supported on this platform (Triggered internally at …\c10/cuda/CUDAAllocatorConfig.h:30.)
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

as well but didn’t seem to find anything about it online.
About the activations: there are two activations overall and three layers overall (if you count the four parallel conv layers plus the cat as one layer). Should it really be that problematic?
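
I guess I can also measure each dilated conv separately with the same cuda_check idea, something like this (a sketch reusing block and x from above; torch.cuda.reset_peak_memory_stats and torch.cuda.max_memory_allocated are the relevant APIs):

with torch.no_grad():
    h = block.conv(x)  # [1, 12, 144, 160, 144], input to the four dilated branches
for name in ["conv1", "conv2", "conv3", "conv4"]:
    torch.cuda.reset_peak_memory_stats()
    try:
        with torch.no_grad():
            _ = getattr(block, name)(h)
        print(f"{name}: peak {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
    except torch.cuda.OutOfMemoryError:
        print(f"{name}: OOM")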