Autoencoder loss does not vary

Hello Everyone,

I am training an autoencoder based on a ResNet-UNet architecture, and the loss remains constant throughout training. I have tried varying the learning rate, using a learning rate scheduler, and playing around with different optimizers and loss functions (SSE, BCE, etc.), with both normalized and unnormalized data. I also followed the suggestions from similar threads in the PyTorch forum, but was unable to fix the problem. It would be great if someone could point out where I am going wrong.

Thank you

Code

import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
from torchvision import models, transforms
from torchvision.utils import save_image
from torchsummary import summary  # third-party: pip install torchsummary

batch_size = 3
def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )


class ResNetUNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()

        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())
        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        
        x_original = self.conv_original_size0(input)
        
        x_original = self.conv_original_size1(x_original)
        
        
        layer0 = self.layer0(input)
        
        #save_image(layer0,"input_layer0.png")
        
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)
        
        
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        # Debug dumps (commented out; writing images on every forward pass
        # slows training down considerably):
        # print(x.shape)
        # save_image(x[1, 1, :, :], "upsample_layer.png")
        # save_image(layer3[1, 0, :, :], "layer_3.png")
        # save_image(x[1, 0, :, :], "layer_3_x.png")
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        # save_image(layer2[1, 0, :, :], "layer_2.png")
        # save_image(x[1, 0, :, :], "layer_2_x.png")
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        # save_image(layer1[1, 0, :, :], "layer_1.png")
        # save_image(x[1, 0, :, :], "layer_1_x.png")
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        # save_image(layer0[1, 0, :, :], "layer_0.png")
        # save_image(x[1, 0, :, :], "layer_0_x.png")
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)
        

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)
        
        # save_image(out[1, 0, :, :], "Last_layer_ch1.png")
        # save_image(out[1, 1, :, :], "Last_layer_ch2.png")
        # save_image(out[1, 2, :, :], "Last_layer_ch3.png")
        return out
    
#**********************************************************************************************************************************************************
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNetUNet(n_class=3).to(device)
summary(model, input_size=(3, 224, 224), device=device.type)  # torchsummary expects the model and device to match


transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

trainset = torchvision.datasets.ImageFolder("../data/train", transform=transform, target_transform=None)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=batch_size, num_workers=2, drop_last=True)

testset = torchvision.datasets.ImageFolder("../data/test", transform=transform, target_transform=None)
testloader = torch.utils.data.DataLoader(testset, shuffle=True, batch_size=batch_size, num_workers=2, drop_last=True)

#autoencoder_criterion = nn.MSELoss()
optimizer_ft = optim.Adam(model.parameters(), lr = 1e-2)
#optimizer_ft = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=20, gamma=0.1)

#**********************************************************************************************************************************************************

def calc_loss(pred, target):
    autoencoder_criterion = nn.MSELoss()
    loss = autoencoder_criterion(pred, target)
    # Alternative that was tried: BCE. Note that nn.BCELoss expects targets
    # in [0, 1], so it cannot be used with inputs normalized to mean 0 / std 1.
    # m = nn.Sigmoid()
    # autoencoder_criterion = nn.BCELoss()
    # loss = autoencoder_criterion(m(pred), target)
    return loss

def train_model(model, optimizer, scheduler, num_epochs=25):

    model.train()  # set training mode explicitly (affects the BatchNorm layers in the ResNet encoder)
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        since = time.time()

        run_loss = 0
        
        for data in trainloader:
            optimizer.zero_grad()
            inputs, _ = data
            inputs = inputs.to(device)
            outputs = model(inputs)
            loss = calc_loss(outputs, inputs)
            loss.backward(retain_graph=True)
            optimizer.step()
            time_elapsed = time.time() - since
            run_loss += loss.item() * inputs.size(0)
        
        run_loss = run_loss / len(trainset)
        print("scheduler=",scheduler.get_last_lr())
        scheduler.step()
        print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, run_loss))
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))


    for idx in np.arange(3):
        save_image(outputs[idx],"Train_pred_output_image%d.png"%idx)
        save_image(inputs[idx],"Train_pred_input_image%d.png"%idx)


#**********************************************************************************************************************************************************
    
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_class = 3
model = ResNetUNet(num_class).to(device)
train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=50)

#**********************************************************************************************************************************************************
model.eval()   # set the model to evaluation mode

inputs_test, _ = next(iter(testloader))
inputs_test = inputs_test.to(device)
with torch.no_grad():
    pred = model(inputs_test)
    # pred = torch.sigmoid(pred)
for idx in np.arange(3):
    save_image(inputs_test[idx], "result_input_image%d.png" % idx)
    save_image(pred[idx], "result_output_image%d.png" % idx)

Could you check whether the model parameters are getting valid gradients by printing their .grad attributes after the first backward call?
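Something like this (just a sketch) would show whether all parameters receive a gradient:

# Run once after the first loss.backward(). Parameters whose grad is None or
# all-zero would point to a detached or unused part of the model.
for name, param in model.named_parameters():
    if param.grad is None:
        print(f"{name}: no gradient")
    else:
        print(f"{name}: grad abs mean = {param.grad.abs().mean().item():.3e}")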

Also, why do you need to use retain_graph=True here?

loss.backward(retain_graph=True)
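retain_graph=True keeps the intermediate activations alive after the backward pass, which costs memory and is only needed if you backpropagate through (parts of) the same graph again. Since each iteration builds a fresh graph, the plain call should be enough here, e.g.:

loss = calc_loss(outputs, inputs)
loss.backward()  # the graph is freed afterwards; a new one is built next iteration
optimizer.step()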

As general guidance, I would try to overfit a small data sample first (e.g. just 10 samples) and make sure the model is able to overfit it.
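A sketch of such a check, reusing the objects from your script (the step count is arbitrary):

# Overfitting sanity check: train repeatedly on one fixed small batch.
# With a correct setup the loss should drop close to zero.
small_loader = torch.utils.data.DataLoader(trainset, batch_size=10, shuffle=False)
small_inputs, _ = next(iter(small_loader))
small_inputs = small_inputs.to(device)

for step in range(500):
    optimizer_ft.zero_grad()
    recon = model(small_inputs)
    loss = calc_loss(recon, small_inputs)
    loss.backward()
    optimizer_ft.step()
    if step % 50 == 0:
        print(f"step {step}: loss = {loss.item():.6f}")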


@ptrblck

As per your suggestion, I checked the gradients of the parameters after the loss.backward() step. The gradients are in the range of 10^-4 to 10^-7; you can find an example at the end of this post.
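I used a check along these lines (a sketch; the layer shown is the one whose shape appears in the printout below):

# conv_original_size2[0] is the 192 -> 64 conv in the decoder, matching the
# torch.Size([64, 192, 3, 3]) shape printed below.
param = model.conv_original_size2[0].weight
print("Grad_before=", param.grad)           # None before the first backward call
print("Grad_before_shape=", param.shape)
loss = calc_loss(model(inputs), inputs)
loss.backward()
print("Grad_after=", param.grad)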

For the overfitting test, I trained the model on exactly 10 images and I don't get a perfect reconstructed output. I have attached the inputs and the corresponding outputs after overfitting (training the model on the 10 images). The inputs are normalized and are shown in the 1st and 3rd rows; the corresponding outputs are shown in the 2nd and 4th rows.

[Images: result_input_image / result_output_image and result_input_image0 / result_output_image0 — normalized inputs and their reconstructions]

******************************************** Example of gradients obtained ********************************************
Grad_before= None
Grad_before_shape= torch.Size([64, 192, 3, 3])
Grad_after= tensor([[[[ 1.5444e-05,  1.6913e-05,  1.5995e-05],
          [ 1.6880e-05,  1.8491e-05,  1.7459e-05],
          [ 1.8331e-05,  1.9917e-05,  1.8404e-05]],

         [[ 1.9927e-05,  1.7666e-05,  1.5221e-05],
          [ 1.9537e-05,  1.6954e-05,  1.4170e-05],
          [ 1.9170e-05,  1.6359e-05,  1.3342e-05]],

         [[ 2.6886e-05,  2.8304e-05,  3.0287e-05],
          [ 2.7075e-05,  2.8953e-05,  3.1588e-05],
          [ 2.6412e-05,  2.8855e-05,  3.2180e-05]],

         ...,

         [[ 3.6494e-05,  4.1480e-05,  5.3497e-05],
          [ 3.5902e-05,  3.6543e-05,  4.6513e-05],
          [ 4.2026e-05,  3.5562e-05,  4.6378e-05]],

         [[ 2.1511e-05,  5.2198e-05,  8.9409e-05],
          [ 2.1499e-05,  5.0138e-05,  8.8784e-05],
          [ 2.2240e-05,  4.8883e-05,  8.4382e-05]],

         [[ 8.5568e-07,  3.9607e-08,  9.4165e-07],
          [ 1.3130e-08, -2.6260e-07,  1.0790e-07],
          [ 6.6315e-07, -4.9992e-08,  1.2894e-07]]],


        [[[-2.7049e-05, -2.1350e-05, -1.4579e-05],
          [-2.6697e-05, -2.0302e-05, -1.4146e-05],
          [-2.4956e-05, -1.8778e-05, -1.3871e-05]],

         [[-2.1540e-05, -1.5810e-05, -1.6859e-05],
          [-1.8347e-05, -1.3086e-05, -1.5404e-05],
          [-1.6027e-05, -1.1518e-05, -1.5272e-05]],

         [[-4.0438e-05, -4.8218e-05, -5.6844e-05],
          [-4.0368e-05, -4.9173e-05, -5.8418e-05],
          [-3.7268e-05, -4.7455e-05, -5.8082e-05]],

         ...,

         [[-1.3145e-05, -8.9089e-05, -1.2546e-04],
          [-2.2529e-05, -8.7684e-05, -1.1259e-04],
          [-4.3510e-05, -1.1628e-04, -1.4264e-04]],

         [[-8.6003e-05, -1.5247e-04, -1.5601e-04],
          [-8.3001e-05, -1.3158e-04, -1.4844e-04],
          [-9.7446e-05, -1.2937e-04, -1.4405e-04]],

         [[ 2.2745e-06,  8.6252e-06,  1.6749e-06],
          [ 1.0760e-06,  9.5257e-06,  4.2126e-06],
          [-5.9326e-06,  5.3898e-06, -7.5285e-06]]],


        [[[ 5.7302e-05,  3.6736e-05, -1.1809e-05],
          [ 5.0031e-05,  2.4906e-05, -2.9053e-05],
          [ 5.9241e-05,  2.9926e-05, -2.5830e-05]],

         [[-4.2512e-04, -5.1677e-04, -4.7505e-04],
          [-4.5070e-04, -5.4368e-04, -4.9834e-04],
          [-4.5768e-04, -5.5035e-04, -4.9912e-04]],

         [[-4.5637e-04, -4.6583e-04, -3.4601e-04],
          [-4.6171e-04, -4.5984e-04, -3.3626e-04],
          [-4.8818e-04, -4.7194e-04, -3.3743e-04]],

         ...,

         [[-6.0629e-03, -5.9946e-03, -5.1200e-03],
          [-6.2693e-03, -6.2557e-03, -5.3146e-03],
          [-6.2365e-03, -6.2213e-03, -5.2145e-03]],

         [[-7.2563e-04, -2.4535e-04,  2.9517e-04],
          [-8.3638e-04, -3.7627e-04,  1.9238e-04],
          [-6.2643e-04, -2.4000e-04,  2.5850e-04]],

         [[-2.7501e-03, -2.7126e-03, -2.3582e-03],
          [-2.8471e-03, -2.8696e-03, -2.5841e-03],
          [-2.6671e-03, -2.7258e-03, -2.4540e-03]]],


        ...,


        [[[ 5.3254e-05,  6.5035e-05,  1.0622e-04],
          [ 4.4166e-05,  5.8326e-05,  1.0131e-04],
          [ 3.2802e-05,  4.8616e-05,  9.2495e-05]],

         [[ 3.4394e-03,  3.5002e-03,  3.4879e-03],
          [ 3.4689e-03,  3.5324e-03,  3.5205e-03],
          [ 3.4897e-03,  3.5536e-03,  3.5402e-03]],

         [[ 3.6269e-03,  3.5958e-03,  3.5091e-03],
          [ 3.6331e-03,  3.5956e-03,  3.5011e-03],
          [ 3.6592e-03,  3.6140e-03,  3.5089e-03]],

         ...,

         [[ 6.5372e-02,  6.5775e-02,  6.5144e-02],
          [ 6.5743e-02,  6.6213e-02,  6.5610e-02],
          [ 6.5736e-02,  6.6206e-02,  6.5699e-02]],

         [[ 6.1241e-03,  5.5962e-03,  5.1829e-03],
          [ 6.1115e-03,  5.5773e-03,  5.1191e-03],
          [ 6.0164e-03,  5.5238e-03,  5.0697e-03]],

         [[ 2.9045e-02,  2.9361e-02,  2.9242e-02],
          [ 2.9073e-02,  2.9443e-02,  2.9367e-02],
          [ 2.8875e-02,  2.9244e-02,  2.9191e-02]]],


        [[[-2.7105e-05, -3.4002e-05, -5.2283e-05],
          [-2.6984e-05, -3.5369e-05, -5.5386e-05],
          [-2.6098e-05, -3.6343e-05, -5.6643e-05]],

         [[-1.6817e-04, -1.7624e-04, -1.5094e-04],
          [-1.7449e-04, -1.8054e-04, -1.5067e-04],
          [-1.7860e-04, -1.8413e-04, -1.5044e-04]],

         [[-2.2523e-04, -2.2414e-04, -1.7151e-04],
          [-2.2442e-04, -2.2134e-04, -1.6786e-04],
          [-2.2316e-04, -2.1915e-04, -1.6463e-04]],

         ...,

         [[-1.4295e-03, -1.6756e-03, -1.4744e-03],
          [-1.4044e-03, -1.6606e-03, -1.5211e-03],
          [-1.3774e-03, -1.5886e-03, -1.4934e-03]],

         [[-2.7282e-04, -8.7181e-05, -8.0031e-06],
          [-2.9121e-04, -1.0438e-04, -2.2774e-05],
          [-2.7456e-04, -1.0704e-04, -2.9238e-05]],

         [[-6.0413e-04, -5.5349e-04, -3.9992e-04],
          [-5.8806e-04, -5.5365e-04, -4.3254e-04],
          [-5.6101e-04, -5.3138e-04, -4.4425e-04]]],


        [[[-1.1781e-04, -1.2889e-04, -1.5897e-04],
          [-1.1206e-04, -1.2501e-04, -1.5591e-04],
          [-1.0279e-04, -1.1642e-04, -1.4722e-04]],

         [[-4.2027e-03, -4.2763e-03, -4.2969e-03],
          [-4.2305e-03, -4.3049e-03, -4.3248e-03],
          [-4.2183e-03, -4.2923e-03, -4.3104e-03]],

         [[-4.4182e-03, -4.4096e-03, -4.3487e-03],
          [-4.4423e-03, -4.4278e-03, -4.3601e-03],
          [-4.4392e-03, -4.4195e-03, -4.3446e-03]],

         ...,

         [[-7.9604e-02, -8.0565e-02, -8.0215e-02],
          [-7.9806e-02, -8.0850e-02, -8.0523e-02],
          [-7.9351e-02, -8.0334e-02, -8.0084e-02]],

         [[-7.3640e-03, -6.9757e-03, -6.6576e-03],
          [-7.3242e-03, -6.9142e-03, -6.5543e-03],
          [-7.2372e-03, -6.8390e-03, -6.4863e-03]],

         [[-3.5256e-02, -3.5790e-02, -3.5839e-02],
          [-3.5032e-02, -3.5616e-02, -3.5710e-02],
          [-3.4774e-02, -3.5336e-02, -3.5424e-02]]]], device='cuda:0')

As your model cannot overfit this small dataset, you should try to debug it and play around with the hyperparameters etc. to make sure it is able to do so.
E.g. you could remove some layers and work with a very simple model first (even a single conv layer could be a valid experiment).
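A minimal sketch of that experiment, reusing your data pipeline (layer size and step count are illustrative):

# Single-conv baseline trained to reconstruct its input. If even this cannot
# overfit one batch, the data pipeline or training loop is a more likely
# culprit than the U-Net architecture.
tiny = nn.Conv2d(3, 3, kernel_size=3, padding=1).to(device)
opt = optim.Adam(tiny.parameters(), lr=1e-3)
crit = nn.MSELoss()

batch, _ = next(iter(trainloader))
batch = batch.to(device)
for step in range(200):
    opt.zero_grad()
    loss = crit(tiny(batch), batch)
    loss.backward()
    opt.step()
print("final loss:", loss.item())  # should approach zero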