When I use a basic U-Net architecture (the code is at the bottom of this question) and run the following code:
import torch
from torch import nn
import torch.nn.functional as F
from torch import cuda
from functools import partial
import segmentation_models_pytorch as smp
batch_size = 4
device3 = torch.device("cuda:" + str(3))
UNet = BasicUNet(in_channel=1, out_channel=1).to(device3)
n_para3 = sum(p.numel() for p in UNet.parameters()) # 7787393
x = torch.randn((batch_size, 1, 256, 256)).to(device3)
pred = UNet(x)
print(torch.cuda.memory_summary(device=device3))
I get the following output for the memory consumption:
# |===========================================================================|
# |                  PyTorch CUDA memory summary, device ID 3                |
# |---------------------------------------------------------------------------|
# |            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
# |===========================================================================|
# |        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
# |---------------------------------------------------------------------------|
# | Allocated memory      |    1787 MB |    1978 MB |    2313 MB |  538928 KB |
# |       from large pool |    1782 MB |    1974 MB |    2304 MB |  534726 KB |
# |       from small pool |       4 MB |       5 MB |       8 MB |    4202 KB |
# |---------------------------------------------------------------------------|
# | Active memory         |    1787 MB |    1978 MB |    2313 MB |  538928 KB |
# |       from large pool |    1782 MB |    1974 MB |    2304 MB |  534726 KB |
# |       from small pool |       4 MB |       5 MB |       8 MB |    4202 KB |
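For reference: 7,787,393 parameters × 4 bytes (assuming float32 weights) is only about 30 MB, so the weights themselves are a small fraction of the ~1.8 GB of allocated memory shown above.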
So far so good. However, I also played around with the segmentation_models_pytorch package, which provides some predefined models. Running similar code
device1 = torch.device("cuda:" + str(1))
UNet_args = {
    "encoder_name": "resnet18",      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    "encoder_weights": "imagenet",   # use `imagenet` pre-trained weights for encoder initialization
    "in_channels": 1,                # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    "activation": None,
    "classes": 1                     # model output channels (number of classes in your dataset)
}
UNet = smp.Unet(**UNet_args).to(device1)
x = torch.randn((batch_size, 1, 256, 256)).to(device1)
pred = UNet(x)
n_para1 = sum(p.numel() for p in UNet.parameters()) # 14321937
print(torch.cuda.memory_summary(device=device1))
leads to the following output:
# |===========================================================================|
# |                  PyTorch CUDA memory summary, device ID 1                |
# |---------------------------------------------------------------------------|
# |            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
# |===========================================================================|
# |        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
# |---------------------------------------------------------------------------|
# | Allocated memory      |  385224 KB |  385608 KB |  562387 KB |  177163 KB |
# |       from large pool |  360448 KB |  360448 KB |  532600 KB |  172152 KB |
# |       from small pool |   24776 KB |   25160 KB |   29787 KB |    5011 KB |
# |---------------------------------------------------------------------------|
# | Active memory         |  385224 KB |  385608 KB |  562387 KB |  177163 KB |
# |       from large pool |  360448 KB |  360448 KB |  532600 KB |  172152 KB |
# |       from small pool |   24776 KB |   25160 KB |   29787 KB |    5011 KB |
And now my question: What am I doing wrong? Why does my model use almost five times as much memory even though it has only about half as many parameters (≈7.8M vs. ≈14.3M)? Where is all that memory used? I looked through the source code of the segmentation_models_pytorch package and it all looks fairly ‘standard’; at least I couldn’t spot anything that tries to reduce memory consumption. Or is the parameter count just not a reliable proxy for model size? But even the printed model architecture supports the impression that the pre-trained model is the larger one (it basically uses two of my layeredConv-style modules in a row where my model uses only one).
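For completeness, here is how I compare the raw parameter footprints directly (again assuming float32 weights); neither model's weights come anywhere close to the allocated amounts reported above:

def param_mb(model):
    # total size of all parameter tensors in MB
    return sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2

print(param_mb(BasicUNet(in_channel=1, out_channel=1)))  # ~29.7 MB (7,787,393 params)
print(param_mb(smp.Unet(**UNet_args)))                   # ~54.6 MB (14,321,937 params)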
BasicUNet code:
class layeredConv(nn.Module):
    """
    Module to stack two convolutional layers
    """
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        return self.layers(x)


class downsample(nn.Module):
    """
    Module to implement the downsampling parts of the UNet and apply convolutional layers.
    Because we need the output of the last convolutional layer to combine it in the upward
    branch later, we do the downsampling by the MaxPool first.
    """
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.layers = nn.Sequential(
            nn.MaxPool2d(kernel_size=2),
            layeredConv(in_channel, out_channel)
        )

    def forward(self, x):
        return self.layers(x)


class upsample(nn.Module):
    """
    Module to implement the upwards sampling part of the UNet. It is more complicated because we now
    have to combine and reshape the parts from the downward facing branch
    """
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.use_transpose = False
        if self.use_transpose:
            self.up = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
            self.conv = layeredConv(in_channel, out_channel)
        else:
            self.up = partial(F.interpolate, scale_factor=2, mode="nearest")
            self.conv = layeredConv(in_channel + out_channel, out_channel)

    def forward(self, x1, x2):
        """
        Combining the upsampled part (x1) with the old input (x2)
        """
        x1 = self.up(x1)
        # We now have to reshape x1 such that it fits x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # From the docs:
        # [...] to pad the last 2 dimensions of the input tensor, then use
        # (padding_left, padding_right, padding_top, padding_bottom)
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        # Now we are able to concatenate both inputs
        x = torch.cat([x2, x1], dim=1)
        # ... and apply the convolutional layers again
        x = self.conv(x)
        return x
class outputLayer(nn.Module):
    """
    Applies a 1x1 conv layer to get the number of channels to the specified out_channel size
    """
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.conv(x)
        return x
class BasicUNet(nn.Module):
    def __init__(self, in_channel, out_channel, bilinear=None):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        # Now lets start to build the actual layers
        self.initial_layer = layeredConv(in_channel, 64)
        self.down1 = downsample(64, 128)
        self.down2 = downsample(128, 256)
        self.down3 = downsample(256, 512)
        self.up2 = upsample(512, 256)
        self.up3 = upsample(256, 128)
        self.up4 = upsample(128, 64)
        # But we still need the final layer (1x1 conv-layer) to get the number of latent dim we want
        self.hidden = nn.Identity()
        self.final_layer = outputLayer(64, out_channel)
        self.sigmoid = nn.Sigmoid()

    def comp_feature(self, x):
        # Downward branch
        x1 = self.initial_layer(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        # x5 = self.down4(x4)
        # Upward branch
        # x = self.up1(x5, x4)
        x = self.up2(x4, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.hidden(x)
        return x

    def forward(self, x):
        """
        Forwards x through the network.
        """
        x = self.comp_feature(x)
        x = self.final_layer(x)
        return self.sigmoid(x)
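In case it is useful for answering: here is a rough sketch of how one could sum up the activation sizes of all layers with forward hooks (it assumes every leaf module returns a single tensor, which I believe holds for both models), since maybe that is where the difference comes from:

def activation_mb(model, x):
    # Sum the sizes of all leaf-module outputs for one forward pass, in MB.
    sizes, hooks = [], []
    for m in model.modules():
        if len(list(m.children())) == 0:  # leaf modules only
            hooks.append(m.register_forward_hook(
                lambda mod, inp, out: sizes.append(out.numel() * out.element_size())))
    with torch.no_grad():
        model(x)
    for h in hooks:
        h.remove()
    return sum(sizes) / 1024**2

x_cpu = torch.randn((batch_size, 1, 256, 256))
print(activation_mb(BasicUNet(in_channel=1, out_channel=1), x_cpu))  # run on CPU, just to compare relative sizes
print(activation_mb(smp.Unet(**UNet_args), x_cpu))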