CUDA out of memory in 3D U-Net

Hello everyone.

I wondered whether anyone else is using a 3D U-Net in PyTorch and running into CUDA out-of-memory errors. I'm trying to train a 3D U-Net on Colab Pro (16 GB of GPU memory) to predict 2 classes from 3D medical images of size 512×512×N, and I keep hitting a CUDA out-of-memory error.
I've tried torch.cuda.empty_cache(), but it doesn't help.

My U-Net architecture looks like this:

class DoubleConv(nn.Module):
    """(Conv3d -> GroupNorm -> ReLU) x 2, preserving spatial size.

    Args:
        in_channels: number of channels of the input volume.
        out_channels: number of channels produced by both convolutions.
        num_groups: GroupNorm group count; must divide ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, num_groups=1):
        super().__init__()  # required before registering any submodule
        self.double_conv = nn.Sequential(
            # kernel 3 / padding 1 / stride 1 keeps D, H, W unchanged.
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups=num_groups, num_channels=out_channels),
            # In-place ReLU avoids storing an extra full-size activation per layer.
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups=num_groups, num_channels=out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv/norm/activation stages to ``x``."""
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder stage: halve D, H and W with max-pooling, then apply a DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()  # required before registering any submodule
        self.encoder = nn.Sequential(
            nn.MaxPool3d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        """Downsample ``x`` by 2 in every spatial dim and convolve it."""
        return self.encoder(x)

class Up(nn.Module):
    """Decoder stage: upsample, align with the skip connection, concatenate, convolve.

    Args:
        in_channels: channel count AFTER concatenating the upsampled tensor with
            the skip tensor; this is what the DoubleConv receives. In the
            ConvTranspose3d branch the upsampled input is assumed to carry
            ``in_channels // 2`` channels (NOTE(review): confirm against callers).
        out_channels: channels produced by the DoubleConv.
        trilinear: if True, upsample with parameter-free trilinear interpolation;
            otherwise use a learned ConvTranspose3d.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()  # required before registering any submodule
        if trilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            # The original paste lacked this ``else``, so the transposed conv
            # always replaced the trilinear upsampler.
            self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, match it to skip tensor ``x2``, concat on channels, convolve."""
        x1 = self.up(x1)

        # Max-pooling floors odd sizes, so after upsampling x1 can be a voxel
        # short of x2; pad (or crop, when a diff is negative) to match exactly.
        diff_z = x2.size(2) - x1.size(2)
        diff_y = x2.size(3) - x1.size(3)
        diff_x = x2.size(4) - x1.size(4)
        # F.pad takes pairs starting from the LAST dim: (W, H, D).
        x1 = torch.nn.functional.pad(
            x1,
            [
                diff_x // 2, diff_x - diff_x // 2,
                diff_y // 2, diff_y - diff_y // 2,
                diff_z // 2, diff_z - diff_z // 2,
            ],
        )

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class Out(nn.Module):
    """1x1x1 convolution mapping feature channels to per-class logits.

    Args:
        in_channels: channels of the final decoder feature map.
        out_channels: number of output classes.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()  # required before registering any submodule
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels without changing spatial size."""
        return self.conv(x)

class UNet3d(nn.Module):
    """3D U-Net with four downsampling and four upsampling stages.

    Args:
        in_channels: channels of the input volume.
        n_classes: channels of the output logits.
        n_channels: base width; encoder widths are n, 2n, 4n, 8n, 8n.

    Memory note: ``self.conv`` keeps a full-resolution activation with
    ``n_channels`` channels (plus its intermediate layers) alive for the skip
    connection. So activation memory grows linearly with both ``n_channels``
    and the input voxel count. Reducing either (or using
    ``torch.utils.checkpoint``) is the main lever against CUDA OOM.
    Each spatial dim should be at least 16 so the bottleneck is non-empty.
    """

    def __init__(self, in_channels, n_classes, n_channels):
        super().__init__()  # required before registering any submodule
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels

        # Encoder: each Down halves D, H, W.
        self.conv = DoubleConv(in_channels, n_channels)
        self.enc1 = Down(n_channels, 2 * n_channels)
        self.enc2 = Down(2 * n_channels, 4 * n_channels)
        self.enc3 = Down(4 * n_channels, 8 * n_channels)
        self.enc4 = Down(8 * n_channels, 8 * n_channels)

        # Decoder: first arg is the channel count after concatenating with the skip.
        self.dec1 = Up(16 * n_channels, 4 * n_channels)
        self.dec2 = Up(8 * n_channels, 2 * n_channels)
        self.dec3 = Up(4 * n_channels, n_channels)
        self.dec4 = Up(2 * n_channels, n_channels)
        self.out = Out(n_channels, n_classes)

    def forward(self, x):
        """Return per-voxel class logits with the same spatial size as ``x``."""
        x1 = self.conv(x)
        x2 = self.enc1(x1)
        x3 = self.enc2(x2)
        x4 = self.enc3(x3)
        x5 = self.enc4(x4)

        mask = self.dec1(x5, x4)
        mask = self.dec2(mask, x3)
        mask = self.dec3(mask, x2)
        mask = self.dec4(mask, x1)
        return self.out(mask)

# Reproduce with a random input: the error still occurs without real data.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor = torch.randn(1, 1, 512, 512, 64)
NET = UNet3d(in_channels=1, n_classes=2, n_channels=1).to(device)
# Move the input to the same device as the model; avoid shadowing builtin ``input``.
inputs = tensor.to(device)

RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 14.76 GiB total capacity; 13.42 GiB already allocated; 47.75 MiB free; 13.48 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

You could try to decrease the spatial size of your inputs, reduce the number of parameters in your model (e.g. fewer layers or fewer filters), use torch.utils.checkpoint to trade compute for memory, or possibly offload activations to the CPU.

1 Like

Thank you. I reduced the number of channels in the model and it worked!