Memory consumption U-Net

When I use a basic U-Net architecture (code at the bottom of this post) and run the following code:

import torch
from torch import nn
import torch.nn.functional as F
from torch import cuda
from functools import partial
import segmentation_models_pytorch as smp

batch_size = 4

device3 = torch.device("cuda:" + str(3))
UNet = BasicUNet(in_channel=1, out_channel=1).to(device3)
n_para3 = sum(p.numel() for p in UNet.parameters())   # 7787393
x = torch.randn((batch_size, 1, 256, 256)).to(device3)
pred = UNet(x)

print(torch.cuda.memory_summary(device=device3))

I get the following output for the memory consumption:

# |===========================================================================|
# |                  PyTorch CUDA memory summary, device ID 3                 |
# |---------------------------------------------------------------------------|
# |            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
# |===========================================================================|
# |        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
# |---------------------------------------------------------------------------|
# | Allocated memory      |    1787 MB |    1978 MB |    2313 MB |  538928 KB |
# |       from large pool |    1782 MB |    1974 MB |    2304 MB |  534726 KB |
# |       from small pool |       4 MB |       5 MB |       8 MB |    4202 KB |
# |---------------------------------------------------------------------------|
# | Active memory         |    1787 MB |    1978 MB |    2313 MB |  538928 KB |
# |       from large pool |    1782 MB |    1974 MB |    2304 MB |  534726 KB |
# |       from small pool |       4 MB |       5 MB |       8 MB |    4202 KB |

So far so good. However, I also played around with the segmentation_models_pytorch package, which provides some predefined models. Running similar code

device1 = torch.device("cuda:" + str(1))
UNet_args = {
    "encoder_name": "resnet18",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    "encoder_weights": "imagenet",  # use `imagenet` pre-trained weights for encoder initialization
    "in_channels": 1,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    "activation": None,
    "classes": 1  # model output channels (number of classes in your dataset)
}
UNet = smp.Unet(**UNet_args).to(device1)

x = torch.randn((batch_size, 1, 256, 256)).to(device1)
pred = UNet(x)

n_para1 = sum(p.numel() for p in UNet.parameters())   # 14321937
print(torch.cuda.memory_summary(device=device1))

which leads to the following output:

# |===========================================================================|
# |                  PyTorch CUDA memory summary, device ID 1                 |
# |---------------------------------------------------------------------------|
# |            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
# |===========================================================================|
# |        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
# |---------------------------------------------------------------------------|
# | Allocated memory      |  385224 KB |  385608 KB |  562387 KB |  177163 KB |
# |       from large pool |  360448 KB |  360448 KB |  532600 KB |  172152 KB |
# |       from small pool |   24776 KB |   25160 KB |   29787 KB |    5011 KB |
# |---------------------------------------------------------------------------|
# | Active memory         |  385224 KB |  385608 KB |  562387 KB |  177163 KB |
# |       from large pool |  360448 KB |  360448 KB |  532600 KB |  172152 KB |
# |       from small pool |   24776 KB |   25160 KB |   29787 KB |    5011 KB |

And now my question: what am I doing wrong? Why does my model use almost 5 times as much memory even though it has only about half as many parameters (~7.8M vs ~14.3M)? Where is all that memory used? I looked through the source code of the segmentation_models_pytorch package and it all looks fairly 'standard'; at least I couldn't spot anything that tries to reduce memory consumption. Or is the parameter count just not reliable? Looking at the printed model architecture also supports the fact that the pre-trained model is a lot bigger (it basically uses two of the layeredConv modules in a row where my model uses one).
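For reference, the parameters alone can't explain those numbers. A quick back-of-the-envelope check (assuming float32 parameters, 4 bytes each, and reusing the BasicUNet class from below) gives only about 30 MB:

model = BasicUNet(in_channel=1, out_channel=1)
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"{param_bytes / 1024**2:.1f} MB")  # ~29.7 MB for the 7787393 float32 parameters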


UNet code:

class layeredConv(nn.Module):
    """
    Module to stack two convolutional layers
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        return self.layers(x)


class downsample(nn.Module):
    """
    Module to implement the downsampling parts of the UNet and apply convolutional layers.
    Because we need the output of the last convolutional layer to combine it in the upward
    branch later, we do the downsampling by the MaxPool first.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.layers = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            layeredConv(in_channel, out_channel)
        )

    def forward(self, x):
        return self.layers(x)


class upsample(nn.Module):
    """
    Module to implement the upwards sampling part of the UNet. It is more complicated because we now
    have to combine and reshape the parts from the downward facing branch
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.use_transpose = False
        if self.use_transpose:
            self.up = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
            self.conv = layeredConv(in_channel, out_channel)
        else:
            self.up = partial(F.interpolate, scale_factor=2, mode="nearest")
            self.conv = layeredConv(in_channel + out_channel, out_channel)

    def forward(self, x1, x2):
        """
        Combining the upsampled part(x1) with the old input (x2)
        """
        x1 = self.up(x1)

        # We now have to reshape x1 such that it fits x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        # From the docs:
        # [...] to pad the last 2 dimensions of the input tensor, then use
        # (padding_left, padding_right, padding_top, padding_bottom)
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

        # Now we are able to concatenate both inputs
        x = torch.cat([x2, x1], dim=1)
        # ... and apply the convolutional layers again
        x = self.conv(x)
        return x


class outputLayer(nn.Module):
    """
    Applies a 1x1 conv layer to get the number of channels to the specified out_channel size
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.conv(x)
        return x


class BasicUNet(nn.Module):
    def __init__(self, in_channel, out_channel, bilinear=None):
        super().__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel

        # Now lets start to build the actual layers
        self.initial_layer = layeredConv(in_channel, 64)
        self.down1 = downsample(64, 128)
        self.down2 = downsample(128, 256)

        self.down3 = downsample(256, 512)

        self.up2 = upsample(512, 256)
        self.up3 = upsample(256, 128)
        self.up4 = upsample(128, 64)

        # But we still need the final layer (1x1 conv-layer) to get the number of output channels we want
        self.hidden = nn.Identity()
        self.final_layer = outputLayer(64, out_channel)

        self.sigmoid = nn.Sigmoid()

    def comp_feature(self, x):
        # Downward branch
        x1 = self.initial_layer(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        # x5 = self.down4(x4)

        # Upward branch
        # x = self.up1(x5, x4)
        x = self.up2(x4, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        x = self.hidden(x)
        return x

    def forward(self, x):
        """
        Forwards x through the network.
        """
        x = self.comp_feature(x)
        x = self.final_layer(x)

        return self.sigmoid(x)


Depending on the model architecture, the parameters and buffers might use less memory than the forward activations that are stored for the gradient calculation. This post explains it with an example.
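One way to see this directly is to register forward hooks and sum the sizes of the tensors each layer produces. This is just a minimal sketch, assuming the BasicUNet defined above; functional ops (F.interpolate, F.pad, torch.cat) are not captured, so it slightly undercounts:

import torch

# Sum the sizes of the activations produced by every leaf module.
# Restricting to leaf modules avoids counting a Sequential's output
# twice (it is the same tensor as its last child's output).
acts = []

def hook(module, inputs, output):
    if torch.is_tensor(output):
        acts.append(output.numel() * output.element_size())

model = BasicUNet(in_channel=1, out_channel=1).cuda()
handles = [m.register_forward_hook(hook)
           for m in model.modules() if len(list(m.children())) == 0]
model(torch.randn(4, 1, 256, 256, device="cuda"))
print(f"sum of layer outputs: {sum(acts) / 1024**2:.0f} MB")
for h in handles:
    h.remove()

The total should land in the same ballpark as the ~1.8 GB reported by memory_summary above.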

Yes, using more layers will create more intermediate activations, which need to be stored for the backward pass, assuming you want to train the model.
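To confirm that this memory is really held for the backward pass, you can compare the same forward pass with and without autograd recording. A minimal sketch, again assuming the BasicUNet from above:

import torch

# Compare allocated memory with and without autograd recording.
model = BasicUNet(in_channel=1, out_channel=1).cuda()
x = torch.randn(4, 1, 256, 256, device="cuda")

pred = model(x)
print(torch.cuda.memory_allocated() / 1024**2)  # large: activations are kept for backward

del pred  # releases the graph and its saved activations

with torch.no_grad():
    pred = model(x)
print(torch.cuda.memory_allocated() / 1024**2)  # much smaller: no activations stored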

Thanks for your reply, and sorry for continuing with another question, as I still don't understand what's going on. So I ran the following code (with the class and function defined at the bottom of this post):

device2 = torch.device("cuda:" + str(2))
print(f'Before model memory:       {convert_size(torch.cuda.memory_allocated(device=device2))}')
res2 = BasicEnc(1, 1).to(device2)
print(f'After model memory:         {convert_size(torch.cuda.memory_allocated(device=device2))}')
x = torch.randn((batch_size, 1, 256, 256)).to(device2)
print(f'After input memory:        {convert_size(torch.cuda.memory_allocated(device=device2))}')


x1 = res2.initial_layer(x)
print(f'Stage: {0} memory:           {convert_size(torch.cuda.memory_allocated(device=device2))}')
x2 = res2.down1(x1)
print(f'Stage: {1} memory:           {convert_size(torch.cuda.memory_allocated(device=device2))}')
x3 = res2.down2(x2)
print(f'Stage: {2} memory:           {convert_size(torch.cuda.memory_allocated(device=device2))}')
x4 = res2.down3(x3)
print(f'Stage: {3} memory:           {convert_size(torch.cuda.memory_allocated(device=device2))}')

and I get the following output:

Before model memory:       0B
After model memory:        17.9 MB
After input memory:        18.9 MB
Stage: 0 memory:           274.91 MB
Stage: 1 memory:           450.91 MB
Stage: 2 memory:           538.91 MB
Stage: 3 memory:           582.92 MB

I know that the computational graph has to save all the intermediate values in order to backpropagate through it later, which I would guess roughly doubles the memory consumption of each layer. But why the jump from 19 MB to 275 MB?

This is the code defining the model:

import torch
from torch import nn
import torch.nn.functional as F
from torch import cuda
import math

batch_size = 4

def convert_size(size_bytes):
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    i = int(math.floor(math.log(size_bytes, 1024)))
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])


class layeredConv(nn.Module):
    """
    Module to stack two convolutional layers
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layers(x)


class downsample(nn.Module):
    """
    Module to implement the downsampling parts of the UNet and apply convolutional layers.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.layers = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            layeredConv(in_channel, out_channel)
        )

    def forward(self, x):
        return self.layers(x)


class BasicEnc(nn.Module):
    def __init__(self, in_channel, out_channel, bilinear=None):
        super().__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel

        self.initial_layer = layeredConv(in_channel, 64)
        self.down1 = downsample(64, 128)
        self.down2 = downsample(128, 256)
        self.down3 = downsample(256, 512)

    def forward(self, x):
        # Chain the encoder stages; the profiling snippet above calls
        # them individually instead to measure memory after each stage.
        x = self.initial_layer(x)
        x = self.down1(x)
        x = self.down2(x)
        return self.down3(x)

It depends on the activation size, which can be much larger than the model parameters, especially if you are using a ConvNet.
Here is an example:

import torch
import torch.nn as nn

device2 = torch.device("cuda")
print(torch.cuda.memory_allocated()/1024**2)
# 0.0

conv = nn.Conv2d(1, 512, 3, 1, 1).cuda()
print(torch.cuda.memory_allocated()/1024**2)
# 0.01953125

x = torch.randn((1, 1, 256, 256)).to(device2)
print(torch.cuda.memory_allocated()/1024**2)
# 0.26953125

out = conv(x)
print(torch.cuda.memory_allocated()/1024**2)
# 128.26953125

print(out.nelement() * out.element_size() / 1024**2)
# 128.0

The conv layer has only 5120 parameters (conv.weight.nelement() + conv.bias.nelement()), while the output has 33554432 elements and therefore uses much more memory, as seen in my example.
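The same arithmetic explains the jump you saw at stage 0: layeredConv(1, 64) keeps four float32 tensors of shape (4, 64, 256, 256) for the backward pass (the outputs of both convs and both batch norms; the in-place ReLUs reuse the batch-norm outputs), and four such tensors are exactly 256 MB, i.e. the step from ~19 MB to ~275 MB. A rough check (my reading of which tensors autograd saves, so take the exact set with a grain of salt):

per_tensor_mb = 4 * 64 * 256 * 256 * 4 / 1024**2  # batch * channels * H * W * 4 bytes
print(per_tensor_mb)      # 64.0
print(4 * per_tensor_mb)  # 256.0 ~= 274.91 MB - 18.9 MB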