Hi all.
I am trying to train a model for segmentation. It turns out that the model dice loss does not reduce than 1. It means that the model is not training. Can you tell me how do I make sure that I am at least able to overfit on the training data to get started? I have tried various optimizers and various learning rates.
class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
self.ins = in_conv(n_channels, 16) #128x128x128x1 ==> 128x128x128x16
self.en_1 = Encode3d(16, 32) #128x128x128x16 ==> 64x64x64x32
self.en_2 = Encode3d(32, 64) #64x64x64x32 ==> 32x32x32x64
self.en_3 = Encode3d(64, 128) #32x32x32x64 ==> 16x16x16x128
self.br_4 = Bridge3d(128, 256) #16x16x16x128 ==> 8x8x8x256
self.de_3 = Decode3d(256, 128)#......#8x8x8x256, 16x16x16x128
self.de_2 = Decode3d(128, 64)#.......
self.de_1 = Decode3d(64, 32)#
self.de_0 = Decode3d(32, 1)
# self.out = out_conv(24, 1)
def forward(self, x):
x1 = self.ins(x)
x2 = self.en_1(x1)
x3 = self.en_2(x2)
x4 = self.en_3(x3)
x5 = self.br_4(x4)
x = self.de_3(x5, x4)
x = self.de_2(x, x3)
x = self.de_1(x, x2)
x = self.de_0(x, x1)
return x
class in_conv(nn.Module):
def __init__(self, input_channels, output_channels, kernel_size = 3, stride =1,
padding = 1, bias = True, activation = 'relu', res = False):
nn.Module.__init__(self)
self.input_channels = input_channels
self.output_channels= output_channels
self.residual = res
self.bn_1 = nn.InstanceNorm3d(output_channels, affine = True)
self.bn_2 = nn.InstanceNorm3d(output_channels, affine = True)
self.conv1 = nn.Conv3d(input_channels, output_channels, kernel_size = 3,
stride = 1, padding = 1, bias = True)
self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size = 3,
stride = 1, padding = 1, bias = True)
def forward(self, x):
# print("Input conv:", x.shape)
if self.residual == True:
skip = x
x = self.conv1(x)
x = self.bn_1(x)
x = F.leaky_relu(x, negative_slope = 0.01, inplace = True)
x = self.conv2(x)
if self.residual == True:
x = x + skip
# print("Exiting Input Conv:", x.shape)
return x
############# ENCODING MODULE #######################################
class Encode3d(nn.Module):
def __init__(self, input_channels, output_channels, res = False):
super(Encode3d, self).__init__()
self.input_channels = input_channels
self.output_channels= output_channels
self.residual = res
self.bn_1 = nn.InstanceNorm3d(input_channels, affine = True)
self.bn_2 = nn.InstanceNorm3d(output_channels, affine = True)
self.conv1 = nn.Conv3d(input_channels, output_channels, kernel_size = 3,
stride = 1, padding = 1, bias = True)
self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size = 3,
stride = 1, padding = 1, bias = True)
self.maxpool = nn.MaxPool3d(2)
def forward(self, x):
# print("Encoding:", x.shape)
x = self.maxpool(x)
if self.residual == True:
skip = x
x = self.bn_1(x)
x = F.leaky_relu(x)
x = self.conv1(x)
x = self.bn_2(x)
x = F.leaky_relu(x)
x = self.conv2(x)
if self.residual == True:
x = x + skip
# print("Exiting Encoding:", x.shape)
return x
class Bridge3d(nn.Module):
def __init__(self, input_channels, output_channels, res = False):
super(Bridge3d, self).__init__()
self.input_channels = input_channels
self.output_channels= output_channels
self.residual = res
self.bn_1 = nn.InstanceNorm3d(input_channels, affine = True)
self.bn_2 = nn.InstanceNorm3d(output_channels, affine = True)
self.conv1 = nn.Conv3d(input_channels, output_channels, kernel_size = 3,
stride = 1, padding = 1, bias = True)
self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size = 3,
stride = 1, padding = 1, bias = True)
self.maxpool = nn.MaxPool3d(2)
def forward(self, x):
# print("Bridging", x.shape)
x = self.maxpool(x)
if self.residual == True:
skip = x
x = self.bn_1(x)
x = F.leaky_relu(x)
x = self.conv1(x)
x = self.bn_2(x)
x = F.leaky_relu(x)
x = self.conv2(x)
if self.residual == True:
x = x + skip
# print("Exiting Bridging:", x.shape)
return x
class Decode3d(nn.Module):
def __init__(self, input_channels, output_channels, conv_bias=True,
lrelu_inplace=True, Trilinear = True, res = False):
super(Decode3d, self).__init__()
if Trilinear:
self.up = nn.Upsample(scale_factor = 2, mode = 'trilinear', align_corners = True)
else:
self.up = nn.ConvTranspose3d(input_channels, input_channels,
kernel_size = 3, stride = 2, output_padding = 1)
self.lrelu_inplace = lrelu_inplace
self.conv_bias = conv_bias
self.input_channels = input_channels
self.output_channels= output_channels
self.residual = res
self.conv1 = nn.Conv3d(int(input_channels*1.5), output_channels, kernel_size = 3,
stride = 1,padding = 1, bias=conv_bias)
self.bn_1 = nn.InstanceNorm3d(int(input_channels*1.5), affine=True)
self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size = 1,
stride = 1,padding = 0, bias=conv_bias)
self.bn_2 = nn.InstanceNorm3d(output_channels, affine=True)
def forward(self, x1, x2):
# print("Decoding x1: ", x1.shape,"x2: ", x2.shape)
x = self.up(x1)
# print("After upconv, X shape: ", x.shape)
x = torch.cat([x, x2], dim = 1)
# print("After concatenating X shape: ", x.shape)
if self.residual == True:
skip = x
x = self.bn_1(x)
# print("Successful_batch_normalization")
x = F.leaky_relu(x)
x = self.conv1(x)
x = self.bn_2(x)
x = F.leaky_relu(x)
x = self.conv2(x)
if self.residual == True:
x = x + skip
# print("Exiting decoding", x.shape)
return x
class out_conv():
def __init__(self, input_channels, output_channels = 1, kernel_size = 3,
stride = 1, padding = 1, bias = True, activation = 'relu', Trilinear = True,
lrelu_inplace = True, conv_bias = True, res = True):
nn.Module.__init__(self)
if Trilinear:
self.up = nn.Upsample(scale_factor = 2, mode = 'trilinear', align_corners = True)
else:
self.up = nn.ConvTranspose3d(input_channels, output_channels = input_channels, kernel_size = 3,
stride = 2, output_padding = 1)
self.lrelu_inplace = lrelu_inplace
self.conv_bias = conv_bias
self.input_channels = input_channels
self.output_channels= output_channels
self.residual = res
self.bn_1 = nn.InstanceNorm3d(int(input_channels*15), affine = True)
self.conv1 = nn.Conv3d(int(input_channels*1.5), output_channels, kernel_size = 3,
stride = 1, padding = 1, bias = True)
self.bn_2 = nn.InstanceNorm3d(input_channels, affine = True)
self.bn_3 = nn.InstanceNorm3d(output_channels, affine = True)
self.conv2 = nn.Conv3d(input_channels, output_channels, kernel_size = 3,
stride = 1, padding = 1, bias = True)
def forward(self, x1, x2):
# print("Out conving x1:", x1.shape, "x2: ", x2.shape)
x = self.up(x1)
# print("After upconving, X shape:", x.shape)
x = torch.cat([x, x2], dim = 1)
# print("After concatenating X shape: ", x.shape)
if self.residual == True:
skip = x
x = self.bn_1(x)
# print("Successful_batch_normalization")
x = F.leaky_relu(x)
x = self.conv1(x)
x = self.bn_2(x)
x = F.leaky_relu(x)
x = self.conv2(x)
if self.residual == True:
x = x + skip
x = self.bn_3(x)
x = nn.sigmoid(x, negative_slope = 0.01, inplace = True)
return x
def train(epoch, train_loss_list):
model.train()
for batch_idx, (image_name, image, mask, affine) in enumerate(train_loader):
if params['cuda']:
image, mask = image.cuda(), mask.cuda() #Loading images into the GPU and ignoring the affine.
image, mask = Variable(image), Variable(mask)
optimizer.zero_grad()
output = model(image)
loss = criterion(output, mask)
train_loss_list.append(loss.data.item())
loss.backward()
optimizer.step()
if batch_idx % int(params['log_interval']) == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAverage DICE Loss: {:.6f}'.format(
epoch, batch_idx * len(image), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.data.item()))
for param_group in optimizer.param_groups:
print("Learning rate: ", param_group['lr'])
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'loss_state_dict' : criterion.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss.data.item()
}, './tmp/model.pkl')
sys.stdout.flush()
criterion = DiceLoss()
for i in range(int(params['max_epochs'])):
if i % int(params['val_epoch_interval']):
create_epoch_dir(i, params['epoch'])
train(i, train_loss_list)
class DiceLoss(nn.Module):
def __init__(self):
super(DiceLoss, self).__init__()
def forward(self, output, mask):
probs = torch.squeeze(output, 1)
mask = torch.squeeze(mask, 1)
intersection = probs * mask
intersection = torch.sum(intersection, 2)
intersection = torch.sum(intersection, 1)
den1 = probs * probs
den1 = torch.sum(den1, 2)
den1 = torch.sum(den1, 1)
den2 = mask * mask
den2 = torch.sum(den2, 2)
den2 = torch.sum(den2, 1)
eps = 1e-8
dice = 2 * ((intersection + eps) / (den1 + den2 + eps))
# dice_eso = dice[:, 1:]
dice_eso = dice
loss = 1 - torch.sum(dice_eso) / dice_eso.size(0)
return loss
Here is a rough structure of the model. It is a simple UNet.
Can anyone help me with this?
Thanks.