Hi!
I am currently writing a UNet implementation for medical image segmentation, but I am struggling with memory problems, both when running on GPU and on CPU. I've already read some posts and managed to fix the leaks by deleting intermediate tensors during the forward pass, but my model still requires around 16GB of RAM at peak.
Loading the data alone already takes 2.5GB of RAM, and a forward pass with a batch size of 5 then pushes this to a peak of 16GB. You can find most of my code below; any tips on writing memory-efficient code, or on inspecting which parts use the most memory, are greatly appreciated.
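The figures above are rough process-level measurements. A minimal sketch of how such numbers can be checked (psutil is an assumption here, not necessarily what I used; on GPU, torch.cuda reports the tensor allocations directly):

import os
import psutil
import torch

def report_memory(tag=""):
    # Resident set size of this Python process (covers CPU tensors, loaded data, etc.)
    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024**3
    print(f"[{tag}] CPU RSS: {rss_gb:.2f} GB")
    if torch.cuda.is_available():
        # Memory currently held by CUDA tensors, and the peak since program start
        print(f"[{tag}] GPU allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB, "
              f"peak: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")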
Dataset
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

class MedicalData(Dataset):
    def __init__(self, dataset_location, transform=None):
        self.transform = transform
        # Instance attributes, not class attributes: class-level lists would be
        # shared across all instances and keep the data alive at class level
        self.patches = []
        self.labels = []
        for file in os.listdir(dataset_location):
            filename = os.fsdecode(file)
            # Load the training data from pickle files
            if filename.startswith('patches'):
                with open(dataset_location + filename, 'rb') as handle:
                    patches = pickle.load(handle)
                    self.patches.append(patches)
                masks_filename = 'masks' + filename.split('patches')[1]
                print(filename, masks_filename)
                with open(dataset_location + masks_filename, 'rb') as handle:
                    labels = pickle.load(handle)
                    self.labels.append(labels)

        # Flatten the per-file lists into one list of samples
        self.patches = [item for sublist in self.patches for item in sublist]
        self.labels = [item for sublist in self.labels for item in sublist]
        self.patches = self.patches[:10]  # Keep only 10 samples while debugging
        self.labels = self.labels[:10]
        assert len(self.patches) == len(self.labels)

    def __getitem__(self, index):
        patch = self.patches[index]
        patch = np.expand_dims(patch, axis=0)  # Add the channel dimension
        if self.transform is not None:
            patch = self.transform(patch)
        # Convert patch and label to torch tensors
        patch = torch.from_numpy(np.asarray(patch))
        label = torch.from_numpy(np.asarray(self.labels[index]))
        # Convert uint8 data to float tensors
        patch = patch.type(torch.FloatTensor)
        label = label.type(torch.FloatTensor)
        return patch, label

    def __len__(self):
        return len(self.patches)
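For completeness, the dataset is consumed through a standard DataLoader, roughly like this (the path and batch size here are placeholders; batch_size=5 matches the numbers above):

from torch.utils.data import DataLoader

train_data = MedicalData('data/train/', transform=None)
train_loader = DataLoader(train_data, batch_size=5, shuffle=True)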
UNet
class Unet(nn.Module):
    # input_channels: number of channels in the input image
    # num_filters: number of filters in the first conv layer (doubled at each depth)
    def __init__(self, input_channels, num_classes, num_filters, depth, padding=False):
        super(Unet, self).__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.depth = depth
        self.padding = padding
        print("No of classes: %d \nNo of input channels: %d \nNo of filters first layer: %d \nDepth of the network: %d \nPadding: %d"
              % (self.num_classes, self.input_channels, self.num_filters, self.depth, self.padding))

        self.contracting_path = nn.ModuleList()
        for i in range(depth):
            in_ch = self.input_channels if i == 0 else out_ch
            out_ch = self.num_filters * (2 ** i)
            self.contracting_path.append(DownConvBlock(in_ch, out_ch, padding))

        self.upsampling_path = nn.ModuleList()
        for i in range(depth - 1):
            in_ch = out_ch
            out_ch = in_ch // 2
            self.upsampling_path.append(UpConvBlock(in_ch, out_ch, padding))

        self.last_layer = nn.Conv2d(out_ch, num_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.contracting_path):
            x = down(x)
            # Store the pre-pooling activation for the skip connection,
            # except at the bottleneck (last block)
            if i != len(self.contracting_path) - 1:
                blocks.append(x)
                #x = F.avg_pool2d(x, 2)
                x = F.max_pool2d(x, 2)
        for i, up in enumerate(self.upsampling_path):
            x = up(x, blocks[-i - 1])
        del blocks  # Delete the stored skip tensors to fix the memory leak
        return self.last_layer(x)
class DownConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, padding):
        super(DownConvBlock, self).__init__()
        layers = []
        layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=int(padding)))
        layers.append(nn.BatchNorm2d(output_dim))
        layers.append(nn.ReLU(inplace=True))  # inplace=True saves memory by reusing the input buffer
        layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=int(padding)))
        layers.append(nn.BatchNorm2d(output_dim))
        layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, patch):
        return self.layers(patch)
class UpConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim, padding, bilinear=False):
        super(UpConvBlock, self).__init__()
        if bilinear:
            #self.upconv_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.upconv_layer = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2, align_corners=True),
                                              nn.Conv2d(input_dim, output_dim, kernel_size=1))
        else:
            self.upconv_layer = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=2, stride=2)
        # After concatenation with the cropped skip tensor, the block sees
        # input_dim (= 2 * output_dim) channels again
        self.conv_block = DownConvBlock(input_dim, output_dim, padding)

    def forward(self, x, bridge):
        up = self.upconv_layer(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        del up    # Delete intermediates to fix the memory leak
        del crop1
        return self.conv_block(out)
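One note: center_crop is not included above. For reference, a standard UNet-style center crop looks roughly like this sketch (my actual version may differ):

# Sketch of the omitted helper on UpConvBlock: symmetric spatial crop of the
# skip tensor so it matches the upsampled tensor before concatenation
def center_crop(self, layer, target_size):
    _, _, layer_height, layer_width = layer.size()
    diff_y = (layer_height - target_size[0]) // 2
    diff_x = (layer_width - target_size[1]) // 2
    return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]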
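Finally, this is roughly how I run the model; the hyperparameters here are placeholders, not my exact settings:

# Hypothetical settings, just to show the call; with num_filters=64 the largest
# stored skip tensor alone is batch x 64 x H x W floats at full resolution
model = Unet(input_channels=1, num_classes=2, num_filters=64, depth=5, padding=True)

for patch, label in train_loader:
    output = model(patch)
    # Autograd keeps every intermediate activation alive until backward(),
    # so peak memory is dominated by these stored activations, not the weights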