U-Net not converging

I’m trying to implement the Recurrent Attention U-Net (R2AttU-Net) model from here. I’ve modified the code a bit for my purposes. Here is my implementation:

import torch
import torch.nn as nn

class R2AttU_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=2, t=2):
        super(R2AttU_Net, self).__init__()
        ## Just basic stuff from the original code: RRCNN blocks, max-pooling,
        ## up-convolutions, attention blocks, and the final 1x1 conv
    def encoder(self, x):
        # Contracting path: recurrent residual conv blocks with max-pooling in between
        x1 = self.RRCNN1(x)
        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)
        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)
        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)
        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)
        return x5, x4, x3, x2, x1
    
    def decoder(self, x5, x4, x3, x2, x1):
        # Expanding path: upsample, gate the skip connection with attention,
        # concatenate, then refine with a recurrent residual conv block
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5,x=x4)
        d5 = torch.cat((x4,d5),dim=1)
        d5 = self.Up_RRCNN5(d5)
        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4,x=x3)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_RRCNN4(d4)
        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3,x=x2)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_RRCNN3(d3)
        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2,x=x1)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_RRCNN2(d2)
        d1 = self.Conv_1x1(d2)  # map to output_ch channels
        return d1, d2, d3, d4, d5

    def forward(self, x):
        x5, x4, x3, x2, x1 = self.encoder(x)
        d, _, _, _, _ = self.decoder(x5, x4, x3, x2, x1)
        return d
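
For reference, a quick forward-pass shape check (the 256x256 input size here is an assumption):

model = R2AttU_Net(img_ch=3, output_ch=2, t=2)
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # expected: torch.Size([1, 2, 256, 256])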

There is nothing exceptional in my training process: just the usual loss calculation, followed by loss.backward(), optimizer.step(), and optimizer.zero_grad(). But it’s not working: the loss decreases for a while and then starts to oscillate. I’ve tried the same model with the same set of parameters (learning rate, gradient accumulation steps) without using the encoder() and decoder() functions, and it works. What am I doing wrong?
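
In other words, the loop is roughly this (a sketch; train_loader, model, criterion, and optimizer stand in for my actual objects):

for images, masks in train_loader:
    outputs = model(images)
    loss = criterion(outputs, masks)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()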

As a quick test, I would recommend trying to overfit a small data sample (e.g. just 10 samples) and making sure your model is able to overfit it, as in the sketch below.
If that’s not the case, you would need to play around with some hyperparameters (e.g. the learning rate).
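
Something like this (a sketch; the input size, target format, criterion, and learning rate are assumptions):

import torch
import torch.nn as nn

model = R2AttU_Net(img_ch=3, output_ch=2, t=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

# a tiny fixed batch of 10 random samples
x = torch.randn(10, 3, 256, 256)
y = torch.randint(0, 2, (10, 2, 256, 256)).float()

for step in range(500):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(step, loss.item())  # the loss should approach 0 if the model can overfit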

Sorry for the delayed response. I followed your suggestion (ran on a small subset of 100 samples, increased and decreased the learning rate) and ran the script several times. My model does not even overfit the training data. It seems like backward() is not working. However, when I switch back to the original implementation, it works again.

What do you mean by “it seems like backward() is not working”? Is your model not getting any gradients, or are you seeing other issues?
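
One quick way to check this is to inspect the parameter gradients right after loss.backward():

for name, param in model.named_parameters():
    if param.grad is None:
        print(f"{name}: no gradient")
    else:
        print(f"{name}: grad norm {param.grad.norm().item():.3e}")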

If the original implementation is working, you could compare both models and try to debug the issue by adding your changes step by step, making sure the model still works after each change.
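
For example, you could load the same state_dict into both models and compare their outputs on the same input (a sketch; OriginalR2AttU_Net is a placeholder name for the unmodified reference model, and matching parameter names are assumed):

import torch

torch.manual_seed(0)
ref = OriginalR2AttU_Net(img_ch=3, output_ch=2, t=2)  # placeholder: the reference implementation
mod = R2AttU_Net(img_ch=3, output_ch=2, t=2)
mod.load_state_dict(ref.state_dict())
ref.eval()
mod.eval()

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    print(torch.allclose(ref(x), mod(x), atol=1e-6))  # True if both models match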


Sorry for the really long delay in my response; I got carried away with some other projects. Here is my customized loss function for the training pipeline:

import torch

def dice_loss(pred, target, threshold=0.5, smooth=1e-6):
    pred = torch.sigmoid(pred)
    pred = (pred > threshold).float()  # hard threshold on the probabilities
    num = pred.size(0)
    m1 = pred.view(num, -1)  # flatten predictions per sample
    m2 = target.view(num, -1)  # flatten targets per sample
    intersection = (m1 * m2).sum()
    dice_score = (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
    return 1 - dice_score

I think I’m hurting the gradients by applying the threshold. Let me know whether I am right. I’ve removed the threshold part and now the model seems to converge.
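
For reference, the version without the threshold looks like this:

def dice_loss(pred, target, smooth=1e-6):
    # keep the soft sigmoid probabilities so the loss stays differentiable;
    # a hard threshold can still be applied at evaluation time for the Dice metric
    pred = torch.sigmoid(pred)
    num = pred.size(0)
    m1 = pred.view(num, -1)
    m2 = target.view(num, -1)
    intersection = (m1 * m2).sum()
    dice_score = (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
    return 1 - dice_score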

Yes, the pred > threshold operation won’t get any gradients, since the comparison is not differentiable. You can see this in the missing .grad_fn attribute of the output tensor.
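
A minimal example showing the detachment:

import torch

x = torch.randn(4, requires_grad=True)
prob = torch.sigmoid(x)
print(prob.grad_fn)        # <SigmoidBackward0 object at ...>
hard = (prob > 0.5).float()
print(hard.grad_fn)        # None: the comparison detached the tensor from the graph
print(hard.requires_grad)  # False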