Writing memory-efficient code

Hi!

I am currently writing a UNet implementation for medical image segmentation, but I am struggling with memory problems both on the GPU and on the CPU. I have already read some posts and managed to fix the leaks by deleting intermediate tensors during the forward pass, but my model still requires around 16GB of RAM at its peak.

Just loading the data already requires 2.5GB of RAM, and a forward pass with a batch size of 5 pushes this up to about 16GB. You can find most of my code below; any tips on writing memory-efficient code, or on inspecting which parts use the most memory, are really appreciated.
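For reference, a minimal sketch of the kind of measurement behind the numbers above (using psutil for CPU RSS and torch.cuda.max_memory_allocated for the GPU; this is illustrative, not my exact script):

import os
import psutil
import torch

def report_memory(tag=""):
    # Resident set size of this process = CPU RAM actually in use
    rss = psutil.Process(os.getpid()).memory_info().rss
    print(f"{tag}: CPU RSS {rss / 1024**3:.2f} GB")
    if torch.cuda.is_available():
        # Peak GPU memory allocated by tensors so far
        print(f"{tag}: GPU peak {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

# report_memory("after loading data")
# report_memory("after forward pass")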

Dataset

import os
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset


class MedicalData(Dataset):
    def __init__(self, dataset_location, transform=None):
        self.transform = transform
        # Instance attributes (class-level lists would be shared across all instances)
        self.patches = []
        self.labels = []

        for file in os.listdir(dataset_location):
            filename = os.fsdecode(file)

            #Load the training data from pickle files
            if filename.startswith('patches'):
                with open(dataset_location + filename, 'rb') as handle:
                    patches = pickle.load(handle)
                    self.patches.append(patches)

                    masks_filename = 'masks' + filename.split('patches')[1]
                    print(filename, masks_filename)
                    with open(dataset_location + masks_filename, 'rb') as handle:
                        labels = pickle.load(handle)
                        self.labels.append(labels)

        #Flatten the lists into one list
        self.patches = [item for sublist in self.patches for item in sublist]
        self.labels = [item for sublist in self.labels for item in sublist]

        self.patches = self.patches[:10]
        self.labels = self.labels[:10]
        assert (len(self.patches) == len(self.labels))

    def __getitem__(self, index):
        patch = self.patches[index]
        patch = np.expand_dims(patch, axis=0)

        if self.transform is not None:
            patch = self.transform(patch)

        # Convert patch and label to torch tensors
        patch = torch.from_numpy(np.asarray(patch))
        label = torch.from_numpy(np.asarray(self.labels[index]))

        #Convert uint8 to float tensors
        patch = patch.type(torch.FloatTensor)
        label = label.type(torch.FloatTensor)

        return patch, label

    def __len__(self):
        return len(self.patches)

UNet

import torch
import torch.nn as nn
import torch.nn.functional as F


class Unet(nn.Module):
    # input_channels is the number of channels in the input image
    # num_filters is the number of filters in the first conv layer

    def __init__(self, input_channels, num_classes, num_filters, depth, padding=False):
        super(Unet, self).__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.depth = depth
        self.padding = padding
        print("No of classes: %d \nNo of input channels: %d \nNo of filters first layer: %d \nDepth of the network: %d \nPadding: %d"
        % (self.num_classes, self.input_channels, self.num_filters, self.depth, self.padding))

        self.contracting_path = nn.ModuleList()
        for i in range(depth):
            input = self.input_channels if i == 0 else output
            output = self.num_filters*(2**i)
            self.contracting_path.append(DownConvBlock(input, output, padding))

        self.upsampling_path = nn.ModuleList()
        for i in range(depth-1):
            input = output
            output = input // 2
            self.upsampling_path.append(UpConvBlock(input, output, padding))

        self.last_layer = nn.Conv2d(output, num_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.contracting_path):
            x = down(x)
            if i != len(self.contracting_path)-1:
                blocks.append(x)
                #x = F.avg_pool2d(x, 2)
                x = F.max_pool2d(x, 2)

        for i, up in enumerate(self.upsampling_path):
            x = up(x, blocks[-i-1])

        del blocks #Delete to fix memory leak
        return self.last_layer(x)


class DownConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, padding):
        super(DownConvBlock, self).__init__()
        layers = []

        layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=int(padding)))
        layers.append(nn.BatchNorm2d(output_dim))
        layers.append(nn.ReLU(inplace=True))  # inplace=True saves memory by overwriting the input tensor

        layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=int(padding)))
        layers.append(nn.BatchNorm2d(output_dim))
        layers.append(nn.ReLU(inplace=True))

        self.layers = nn.Sequential(*layers)

    def forward(self, patch):
        return self.layers(patch)


class UpConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, padding, bilinear=False):
        super(UpConvBlock, self).__init__()
        if bilinear:
            #self.upconv_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.upconv_layer = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2, align_corners=True), nn.Conv2d(input_dim, output_dim, kernel_size=1))
        else:
            self.upconv_layer = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=2, stride=2)
        self.conv_block = DownConvBlock(input_dim, output_dim, padding)

    def center_crop(self, layer, target_size):
        # Crop the skip-connection feature map to the spatial size of the upsampled map
        _, _, h, w = layer.size()
        diff_y = (h - target_size[0]) // 2
        diff_x = (w - target_size[1]) // 2
        return layer[:, :, diff_y:diff_y + target_size[0], diff_x:diff_x + target_size[1]]

    def forward(self, x, bridge):
        up = self.upconv_layer(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        del up
        del crop1
        return self.conv_block(out)

Hi,

This looks like a fairly standard model.
Just a few questions:

  • What are the values of depth and num_filters?
  • What is the size of your inputs in width and height?
  • This output = self.num_filters*(2**i) looks like it’s going to grow very fast, no? Is that expected to be exponential in the depth? (See the quick check below.)
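As a quick check of how fast that grows (hypothetical values here, substitute your actual num_filters and depth):

num_filters, depth = 64, 4
print([num_filters * (2 ** i) for i in range(depth)])
# prints [64, 128, 256, 512] -> the channel count doubles at every level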

Hi, thanks for your reply!

The num_filters is 64 and the depth is 4. The input images are 512x512x1.

This is an example of the network’s structure:

Unet(
  (contracting_path): ModuleList(
    (0): DownConvBlock(
      (layers): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace)
      )
    )
    (1): DownConvBlock(
      (layers): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace)
      )
    )
    (2): DownConvBlock(
      (layers): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace)
      )
    )
    (3): DownConvBlock(
      (layers): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace)
      )
    )
  )
  (upsampling_path): ModuleList(
    (0): UpConvBlock(
      (upconv_layer): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
      (conv_block): DownConvBlock(
        (layers): Sequential(
          (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace)
          (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace)
        )
      )
    )
    (1): UpConvBlock(
      (upconv_layer): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
      (conv_block): DownConvBlock(
        (layers): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace)
        )
      )
    )
    (2): UpConvBlock(
      (upconv_layer): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
      (conv_block): DownConvBlock(
        (layers): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace)
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace)
        )
      )
    )
  )
  (last_layer): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
)

Hi,

Maybe you have a bit too many channels? A conv like Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) applied to a 512-channel feature map of spatial size 512x512 is going to be quite memory heavy!
It produces ((512 + 2*1 - 3)/1 + 1) * ((512 + 2*1 - 3)/1 + 1) = 512 * 512 = 262144 output positions, each gathering a patch of 512 * 3 * 3 = 4608 values.
This means the unfolding (im2col) for that single convolution contains batch x 262144 x 4608 ≈ batch x 1.2 billion elements.
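If you want to see where the activations actually pile up, a rough sketch like the following can help (forward hooks that just record the size of each module's output; it ignores everything autograd saves for the backward pass, so treat the numbers as a lower bound; the function name here is just for illustration):

import torch

def report_activation_sizes(model, input_shape=(1, 1, 512, 512)):
    sizes = {}
    hooks = []
    # Hook every leaf module and record the size of its output tensor
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            def hook(mod, inp, out, name=name):
                if torch.is_tensor(out):
                    sizes[name] = out.numel() * out.element_size()
            hooks.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(torch.zeros(*input_shape))
    for h in hooks:
        h.remove()
    for name, nbytes in sorted(sizes.items(), key=lambda kv: -kv[1]):
        print(f"{name}: {nbytes / 1024**2:.1f} MB")

# report_activation_sizes(Unet(input_channels=1, num_classes=1, num_filters=64, depth=4, padding=True))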

Thanks for having a look! I am running this model on my 1080 Ti and I get a CUDA out-of-memory error immediately. What is the best way to spread the load between my RAM and my GPU, so that I can run the full model?

Spreading the load between RAM and the GPU is a bit tricky; there is no built-in tool for that at the moment.
You can try wrapping your large DownConvBlock and UpConvBlock calls in torch.utils.checkpoint, which is built for exactly this purpose.
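Something along these lines, as an untested sketch against the forward you posted (keeping your existing F = torch.nn.functional import):

from torch.utils.checkpoint import checkpoint

def forward(self, x):  # checkpointed version of Unet.forward
    blocks = []
    for i, down in enumerate(self.contracting_path):
        # checkpoint() recomputes down(x) during backward instead of storing its activations
        x = checkpoint(down, x)
        if i != len(self.contracting_path) - 1:
            blocks.append(x)
            x = F.max_pool2d(x, 2)
    for i, up in enumerate(self.upsampling_path):
        x = checkpoint(up, x, blocks[-i - 1])
    return self.last_layer(x)

Two caveats: if the input image does not require grad, PyTorch will warn that none of the checkpoint inputs require gradients and the first block's parameters may not receive gradients, and BatchNorm running statistics are updated a second time when a block is recomputed.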

Clear! I will have a look at checkpointing. If you have any other tips for memory efficiency, please let me know! 🙂