Hello everyone!
I'm quite new to PyTorch and to deep learning. I'm trying to train a 3D U-Net for brain-MRI segmentation, starting from the most basic U-Net architecture, and I'm running into GPU memory problems. Each sample is a set of four 3D volumes (the four input channels), and I'm passing one sample at a time. To be honest, I don't know where the problem is, so I'd appreciate your help. My GPU has 10.92 GiB of total capacity. Here is my code:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# UNet network
class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock, self).__init__()
        # two unpadded ("valid") 3D convolutions, as in the original U-Net,
        # so each block shrinks the volume by 4 voxels per dimension
        self.conv = nn.Conv3d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv3d(out_size, out_size, kernel_size)
        self.activation = activation

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        return out
class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose3d(in_size, out_size, 2, stride=2)
        self.conv = nn.Conv3d(in_size, out_size, kernel_size)
        self.conv2 = nn.Conv3d(out_size, out_size, kernel_size)
        self.activation = activation

    def center_crop(self, layer, target_size):
        # crop the skip-connection tensor to the (smaller) upsampled size;
        # 3D feature maps are 5D: (batch, channels, depth, height, width)
        _, _, d, h, w = layer.size()
        td, th, tw = target_size
        d1, h1, w1 = (d - td) // 2, (h - th) // 2, (w - tw) // 2
        return layer[:, :, d1:d1 + td, h1:h1 + th, w1:w1 + tw]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.size()[2:])
        out = torch.cat([up, crop1], 1)
        out = self.activation(self.conv(out))
        out = self.activation(self.conv2(out))
        return out
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.activation = F.relu
        self.pool1 = nn.MaxPool3d(2)
        self.pool2 = nn.MaxPool3d(2)
        self.pool3 = nn.MaxPool3d(2)
        self.pool4 = nn.MaxPool3d(2)
        self.conv_block1_64 = UNetConvBlock(4, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)
        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.conv_block512_1024 = UNetConvBlock(512, 1024)
        self.up_block1024_512 = UNetUpBlock(1024, 512)
        self.up_block512_256 = UNetUpBlock(512, 256)
        self.up_block256_128 = UNetUpBlock(256, 128)
        self.up_block128_64 = UNetUpBlock(128, 64)
        self.last = nn.Conv3d(64, 4, 1)

    def forward(self, x):
        # contracting path
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)
        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)
        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)
        block4 = self.conv_block256_512(pool3)
        pool4 = self.pool4(block4)
        block5 = self.conv_block512_1024(pool4)
        # expanding path with skip connections
        up1 = self.up_block1024_512(block5, block4)
        up2 = self.up_block512_256(up1, block3)
        up3 = self.up_block256_128(up2, block2)
        up4 = self.up_block128_64(up3, block1)
        # log-probabilities over the 4 classes (channel dimension)
        return F.log_softmax(self.last(up4), dim=1)
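As a sanity check on where the memory goes, I traced how the spatial size evolves through the encoder for my 240x240x155 input (each UNetConvBlock loses 4 voxels per dimension because the convolutions are unpadded, and MaxPool3d(2) halves with floor division). This is just a throwaway helper I wrote, not part of the network:

# trace the encoder's spatial sizes: two valid 3x3x3 convs (-4), then pool (//2)
def encoder_sizes(dims, levels=5):
    sizes = []
    for level in range(levels):
        dims = [d - 4 for d in dims]       # two unpadded convs per block
        sizes.append(list(dims))
        if level < levels - 1:
            dims = [d // 2 for d in dims]  # MaxPool3d(2)
    return sizes

for level, dims in enumerate(encoder_sizes([240, 240, 155]), start=1):
    print('block%d: %s' % (level, dims))
# block1: [236, 236, 151]
# block2: [114, 114, 71]
# block3: [53, 53, 31]
# block4: [22, 22, 11]
# block5: [7, 7, 1]

Two things stand out to me: the early feature maps stay near full resolution while already carrying 64-128 channels, and the depth collapses to 1 at the bottleneck, so the decoder's unpadded convolutions would underflow there even if memory weren't the issue.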
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# normalize the images
trainset = norm(trainset)
testset = norm(testset)
trainset = np.expand_dims(trainset, axis=0)

net = UNet()
net.to(device)

# create the optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)

# the network already ends in log_softmax, so use NLLLoss;
# CrossEntropyLoss would apply log_softmax a second time
criterion = nn.NLLLoss()
labels = seg.astype(np.int64)  # integer class indices, not one-hot

epochs = 20
for epoch in range(epochs):
    running_loss = 0.0
    for i, (data, label) in enumerate(zip(trainset, labels)):
        input_img = np.expand_dims(data, axis=0)   # add the batch dimension
        label_img = np.expand_dims(label, axis=0)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        input_img = torch.tensor(input_img, dtype=torch.float32).to(device)
        # NLLLoss wants long targets with the same spatial size as the
        # output (the unpadded convolutions shrink it, so the labels have
        # to be cropped to match)
        target = torch.tensor(label_img, dtype=torch.long).to(device)
        outputs = net(input_img)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print('[%d] loss: %.3f' % (epoch + 1, running_loss / (i + 1)))
print('Finished Training')
When I check the GPU status after moving the net to the device, 929 MiB are occupied; as soon as the training loop starts, I get: RuntimeError: CUDA out of memory. Tried to allocate 479.75 MiB (GPU 0; 10.92 GiB total capacity; 9.87 GiB already allocated; 475.50 MiB free; 1.16 MiB cached). The tensor I'm passing has size torch.Size([1, 4, 240, 240, 155]).
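A rough back-of-the-envelope estimate of the activation memory (just counting the float32 activations that autograd keeps for the backward pass, ignoring cuDNN workspace and parameters, so treat the numbers as a lower bound) already points at the first encoder block:

# rough float32 activation-memory estimate for the first encoder block
bytes_per_float = 4
def mib(n_elems):
    return n_elems * bytes_per_float / 2**20

print(mib(1 * 4 * 240 * 240 * 155))    # input volume: ~136 MiB
print(mib(1 * 64 * 238 * 238 * 153))   # after conv1 of block1: ~2116 MiB
print(mib(1 * 64 * 236 * 236 * 151))   # after conv2 of block1: ~2053 MiB

So the first block alone keeps roughly 4 GiB of activations alive, and the rest of the encoder and decoder add several more, which seems consistent with the 9.87 GiB already allocated. Does that mean a full 240x240x155 volume simply can't fit on this card, and I should train on smaller crops instead?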