I’m trying to implement the Recurrent Attention U-Net (R2AttU-Net) model from here. I’ve modified the code a little for my purposes. Here is my implementation:
import torch
import torch.nn as nn

class R2AttU_Net(nn.Module):
    def __init__(self, img_ch=3, output_ch=2, t=2):
        super(R2AttU_Net, self).__init__()
        ## Just basic stuff from the original code

    def encoder(self, x):
        # Contracting path: recurrent residual conv blocks with max-pooling in between
        x1 = self.RRCNN1(x)

        x2 = self.Maxpool(x1)
        x2 = self.RRCNN2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.RRCNN3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.RRCNN4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.RRCNN5(x5)

        return x5, x4, x3, x2, x1

    def decoder(self, x5, x4, x3, x2, x1):
        # Expanding path: upsample, gate the skip connection with attention,
        # concatenate, then apply the recurrent residual conv block
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)

        # Final 1x1 convolution maps to the output channels
        d1 = self.Conv_1x1(d2)

        return d1, d2, d3, d4, d5

    def forward(self, x):
        x5, x4, x3, x2, x1 = self.encoder(x)
        d, _, _, _, _ = self.decoder(x5, x4, x3, x2, x1)
        return d
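For reference, this is roughly how I sanity-check the output shape with a dummy input (the input size here is arbitrary):

model = R2AttU_Net(img_ch=3, output_ch=2, t=2)
x = torch.randn(1, 3, 256, 256)   # one 3-channel 256x256 image (arbitrary size)
out = model(x)                    # expected shape: (1, 2, 256, 256)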
There is nothing exceptional in my training process: just the usual loss calculation followed by loss.backward(), optimizer.step(), and optimizer.zero_grad(). But it’s not working: the loss decreases for a while and then starts to oscillate. I’ve tried the same model with the same set of parameters (learning rate, gradient accumulation steps) without using the encoder() and decoder() functions, and it works. What am I doing wrong?
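In case it helps, here is roughly what one training step looks like. The loader, criterion, and optimizer names are placeholders for a standard setup, not my exact code, and gradient accumulation is omitted for brevity:

model = R2AttU_Net(img_ch=3, output_ch=2, t=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # placeholder optimizer/lr
criterion = nn.CrossEntropyLoss()                          # placeholder loss

for images, masks in loader:                 # `loader` is a placeholder DataLoader
    outputs = model(images)                  # (N, 2, H, W) logits
    loss = criterion(outputs, masks)         # masks: (N, H, W) class indices
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()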