3D UNet for brain tumor segmentation - problems with GPU memory

Hello everyone!
I’m quite new to PyTorch and to deep learning. I’m trying to train a 3D UNet to perform segmentation, starting with the most basic UNet architecture, and I’m having problems with GPU memory. I’m trying to pass a single set of four 3D brain MRI volumes. To be honest, I don’t know where the problem is, so I’d appreciate your help. I’m using a GPU with 10.92 GiB of total capacity. Here is my code:

# UNET Network

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetConvBlock(nn.Module):
    
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        
        super(UNetConvBlock, self).__init__()
        
        # two unpadded ("valid") 3x3x3 convolutions, as in the original UNet
        self.conv = nn.Conv3d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv3d(out_size, out_size, kernel_size)
        self.activation = activation

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))

        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpBlock, self).__init__()
        # transposed convolution doubles the spatial size, followed by two unpadded convolutions
        self.up = nn.ConvTranspose3d(in_size, out_size, 2, stride=2)
        self.conv = nn.Conv3d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv3d(out_size, out_size, kernel_size)
        self.activation = activation

    def center_crop(self, layer, target_size):
        # crop the skip connection so its spatial size matches the upsampled tensor
        # (3D volumes are 5D tensors: batch, channels, depth, height, width)
        _, _, layer_depth, layer_height, layer_width = layer.size()
        diff_z = (layer_depth - target_size[0]) // 2
        diff_y = (layer_height - target_size[1]) // 2
        diff_x = (layer_width - target_size[2]) // 2
        return layer[:, :,
                     diff_z:(diff_z + target_size[0]),
                     diff_y:(diff_y + target_size[1]),
                     diff_x:(diff_x + target_size[2])]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.size()[2:])
        out = torch.cat([up, crop1], 1)
        out = self.activation(self.conv(out))
        out = self.activation(self.conv2(out))

        return out


class UNet(nn.Module):
    
    def __init__(self):
        
        super(UNet, self).__init__()
#         self.imsize = imsize

        self.activation = F.relu
        
        self.pool1 = nn.MaxPool3d(2)
        self.pool2 = nn.MaxPool3d(2)
        self.pool3 = nn.MaxPool3d(2)
        self.pool4 = nn.MaxPool3d(2)

        self.conv_block1_64 = UNetConvBlock(4, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)
        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.conv_block512_1024 = UNetConvBlock(512, 1024)

        self.up_block1024_512 = UNetUpBlock(1024, 512)
        self.up_block512_256 = UNetUpBlock(512, 256)
        self.up_block256_128 = UNetUpBlock(256, 128)
        self.up_block128_64 = UNetUpBlock(128, 64)

        self.last = nn.Conv3d(64, 4, 1)


    def forward(self, x):
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)
        pool4 = self.pool4(block4)

        block5 = self.conv_block512_1024(pool4)

        up1 = self.up_block1024_512(block5, block4)

        up2 = self.up_block512_256(up1, block3)

        up3 = self.up_block256_128(up2, block2)

        up4 = self.up_block128_64(up3, block1)

        # return raw logits; nn.CrossEntropyLoss below applies log_softmax internally
        return self.last(up4)

import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# normalize the images
trainset = norm(trainset)
testset = norm(testset)

trainset = np.expand_dims(trainset,axis=0)

# input_tr = torch.tensor(trainset, dtype=torch.float32, device=device,requires_grad=True) 

net = UNet()
net.to(device)

# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)
# in your training loop:
optimizer.zero_grad()   # zero the gradient buffers
# define the loss function; CrossEntropyLoss expects raw logits and class-index targets
criterion = nn.CrossEntropyLoss()

# seg holds the ground-truth segmentation maps
labels = seg.astype(np.uint8)


epochs = 20

for epoch in range(epochs): 
    
    running_loss = 0.0

    for i,(data,label) in enumerate(zip(trainset, labels)):
        
        # add a batch dimension: input (1, 4, D, H, W), target (1, D, H, W) with class indices
        input_img = np.expand_dims(data, axis=0)
        label_img = np.expand_dims(label, axis=0)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        input_img = torch.tensor(input_img, dtype=torch.float32).to(device)
        # CrossEntropyLoss expects class-index targets of dtype long (not one-hot); with unpadded
        # convolutions the output is smaller than the input, so the target may also need cropping
        target = torch.tensor(label_img, dtype=torch.long).to(device)
        outputs = net(input_img)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
                

        # print statistics
#         running_loss += loss.item()
#         if i % 2000 == 1999:    # print every 2000 mini-batches
#             print('[%d, %5d] loss: %.3f' %
#                   (epoch + 1, i + 1, running_loss / 2000))
#             running_loss = 0.0

print('Finished Training')

When I check the GPU status after moving the net to the device, the memory occupied is 929 MiB, and as soon as the training loop starts the following error is shown: RuntimeError: CUDA out of memory. Tried to allocate 479.75 MiB (GPU 0; 10.92 GiB total capacity; 9.87 GiB already allocated; 475.50 MiB free; 1.16 MiB cached). The size of the tensor I’m passing is torch.Size([1, 4, 240, 240, 155]).
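
The input tensor itself should be fairly small, if my quick calculation below is right, so I don’t understand what is filling up the card:

# quick sanity check of the input tensor size (float32, i.e. 4 bytes per element)
input_elems = 1 * 4 * 240 * 240 * 155
print('input tensor: %.1f MiB' % (input_elems * 4 / 1024**2))  # about 136 MiB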

Could you try to lower the number of filter kernels in your layers and check how much memory is used then? The intermediate activations might take a lot of memory while you are training your model.
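
To get a feeling for the numbers, here is a rough sketch (not your exact computation graph) of the activation memory taken by the first encoder block alone; each of these tensors is kept around for the backward pass:

# Sketch: activation memory of the first UNetConvBlock only, assuming float32 activations
# and the unpadded 3x3x3 convolutions from the model above. The real footprint is larger
# (deeper blocks, gradients, optimizer state, cuDNN workspaces).
conv1_elems = 1 * 64 * 238 * 238 * 153   # output of self.conv  (240 - 2 = 238, 155 - 2 = 153)
conv2_elems = 1 * 64 * 236 * 236 * 151   # output of self.conv2
total_bytes = (conv1_elems + conv2_elems) * 4
print('first block activations: %.1f GiB' % (total_bytes / 1024**3))  # roughly 4.1 GiB

Halving the number of filters in every block roughly halves all of these activation sizes, so even a modest reduction frees a lot of memory.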

From what I know, the recommended input size for a UNet is 512 x 512, which is already a very large size. You could try:

  • A smaller batch size
  • A smaller input image size (see the sketch after this list)
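
For the second point, one simple option is to downsample the volumes before the forward pass, for example with F.interpolate. This is only a sketch: the scale factor of 0.5 is an assumption, the predictions would have to be upsampled back afterwards, and the resulting sizes still have to stay compatible with the pooling and cropping inside the network:

import torch.nn.functional as F

# Sketch: halve each spatial dimension of the (1, 4, 240, 240, 155) input (assumed factor 0.5).
# This cuts the activation memory of every layer by roughly a factor of 8.
small_input = F.interpolate(input_img, scale_factor=0.5, mode='trilinear', align_corners=False)
outputs = net(small_input)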

[1,4,240,240,155] is quite a large size, and an encoder-decoder network is a very memory-demanding framework. You may have to use patch-wise training (see the sketch below) or use fewer channels in the first and last blocks. Using detection prior to segmentation is also a good idea.
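
To make the patch-wise idea concrete, here is a minimal sketch of sampling one random 3D patch per iteration; random_patch is a hypothetical helper, and the 128 x 128 x 128 patch size is an assumption that has to be chosen so it fits in memory and matches the network’s pooling and cropping:

import numpy as np

def random_patch(volume, label, patch_size=(128, 128, 128)):
    # volume: (channels, D, H, W), label: (D, H, W); returns a random crop of both
    _, d, h, w = volume.shape
    pd, ph, pw = patch_size
    z = np.random.randint(0, d - pd + 1)
    y = np.random.randint(0, h - ph + 1)
    x = np.random.randint(0, w - pw + 1)
    return (volume[:, z:z + pd, y:y + ph, x:x + pw],
            label[z:z + pd, y:y + ph, x:x + pw])

The training loop above would then call random_patch(data, label) right before the numpy arrays are converted to tensors.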