I built a U-Net in PyTorch in two different ways, but the two networks perform differently. Are the two methods equivalent?
First method, building the U-Net as a single module:
import torch
import torch.nn as nn

class UNet_1(nn.Module):
    def __init__(self):
        super().__init__()
        ......
    def forward(self, x):
        x0 = self.conv_0(x)
        x1 = self.down_1(x0)
        x2 = self.down_2(x1)
        x3 = self.down_3(x2)
        x4 = self.down_4(x3)
        u4 = self.upcat_4(x4, x3)
        u3 = self.upcat_3(u4, x2)
        u2 = self.upcat_2(u3, x1)
        u1 = self.upcat_1(u2, x0)
        logits = self.final_conv(u1)
        return logits

net = UNet_1()
Second method, building separate Encoder and Decoder modules and wrapping them in one module:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        ......
    def forward(self, x):
        x0 = self.conv_0(x)
        x1 = self.down_1(x0)
        x2 = self.down_2(x1)
        x3 = self.down_3(x2)
        x4 = self.down_4(x3)
        return x0, x1, x2, x3, x4

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        ......
    def forward(self, x0, x1, x2, x3, x4):
        u4 = self.upcat_4(x4, x3)
        u3 = self.upcat_3(u4, x2)
        u2 = self.upcat_2(u3, x1)
        u1 = self.upcat_1(u2, x0)
        logits = self.final_conv(u1)
        return logits

class UNet_2(nn.Module):
    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec
    def forward(self, x):
        x0, x1, x2, x3, x4 = self.enc(x)
        out = self.dec(x0, x1, x2, x3, x4)  # the decoder needs every skip connection
        return out

enc = Encoder()
dec = Decoder()
unet = UNet_2(enc, dec)
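One way to test whether the two constructions define the same network is to build both from identical layers, copy the weights of one into the other, and compare outputs. The sketch below is only an illustration under assumptions: the question omits the __init__ bodies, so the layer choices (max-pool + conv for each down block, transposed conv + concatenation + conv for each upcat block, 4 channels throughout) and the helper names `UpCat`, `add_encoder_layers`, `add_decoder_layers`, `encode` and `decode` are hypothetical. Note that the decoder has to receive all five feature maps: calling `self.dec(x4)` against a `forward(self, x0, x1, x2, x3, x4)` signature would raise a TypeError rather than train.

```python
import torch
import torch.nn as nn

class UpCat(nn.Module):
    """Hypothetical up block: 2x upsample, concatenate the skip, 3x3 conv + ReLU."""
    def __init__(self, c_in, c_skip, c_out):
        super().__init__()
        self.up = nn.ConvTranspose2d(c_in, c_in, 2, stride=2)
        self.conv = nn.Conv2d(c_in + c_skip, c_out, 3, padding=1)
    def forward(self, x, skip):
        return torch.relu(self.conv(torch.cat([self.up(x), skip], dim=1)))

def add_encoder_layers(m, in_ch=1, c=4):
    # Registers conv_0 and down_1..down_4 on m; each down block halves H and W.
    m.conv_0 = nn.Conv2d(in_ch, c, 3, padding=1)
    for i in range(1, 5):
        setattr(m, f"down_{i}", nn.Sequential(
            nn.MaxPool2d(2), nn.Conv2d(c, c, 3, padding=1), nn.ReLU()))

def add_decoder_layers(m, c=4, n_classes=2):
    # Registers upcat_1..upcat_4 and final_conv on m.
    for i in range(1, 5):
        setattr(m, f"upcat_{i}", UpCat(c, c, c))
    m.final_conv = nn.Conv2d(c, n_classes, 1)

def encode(m, x):  # same body as the encoder forward in the question
    x0 = m.conv_0(x)
    x1 = m.down_1(x0)
    x2 = m.down_2(x1)
    x3 = m.down_3(x2)
    x4 = m.down_4(x3)
    return x0, x1, x2, x3, x4

def decode(m, x0, x1, x2, x3, x4):  # same body as the decoder forward in the question
    u4 = m.upcat_4(x4, x3)
    u3 = m.upcat_3(u4, x2)
    u2 = m.upcat_2(u3, x1)
    u1 = m.upcat_1(u2, x0)
    return m.final_conv(u1)

class UNet_1(nn.Module):
    def __init__(self):
        super().__init__()
        add_encoder_layers(self)
        add_decoder_layers(self)
    def forward(self, x):
        return decode(self, *encode(self, x))

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        add_encoder_layers(self)
    def forward(self, x):
        return encode(self, x)

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        add_decoder_layers(self)
    def forward(self, x0, x1, x2, x3, x4):
        return decode(self, x0, x1, x2, x3, x4)

class UNet_2(nn.Module):
    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec
    def forward(self, x):
        x0, x1, x2, x3, x4 = self.enc(x)
        return self.dec(x0, x1, x2, x3, x4)  # every skip connection is passed on

torch.manual_seed(0)
net1 = UNet_1()
net2 = UNet_2(Encoder(), Decoder())

# Copy net1's weights into net2; only the "enc."/"dec." key prefixes differ.
mapped = {("enc." if k.startswith(("conv_0.", "down_")) else "dec.") + k: v
          for k, v in net1.state_dict().items()}
net2.load_state_dict(mapped)  # default strict=True: any missing or extra key raises

x = torch.randn(2, 1, 32, 32)
with torch.no_grad():
    same = torch.allclose(net1(x), net2(x))
n1 = sum(p.numel() for p in net1.parameters())
n2 = sum(p.numel() for p in net2.parameters())
print(same, n1 == n2)  # both True for this setup
```

If the weights load with strict=True and the outputs match, the two modules compute the same function with the same parameters, so any remaining difference between training runs would have to come from elsewhere, for example different random initialization or data order.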
The training loop looks like this:
opt = torch.optim.Adam(unet.parameters(), lr=1e-4)
for x, y in dataloader:
    pred = unet(x)
    loss = criterion(pred, y)
    opt.zero_grad()
    loss.backward()
    opt.step()
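Whether the optimizer sees the wrapped modules' weights can be checked directly: an nn.Module assigned as an attribute in __init__ (as in self.enc = enc) is registered as a submodule, so the wrapper's parameters() already yields the encoder's and decoder's tensors. The sketch below runs the loop above on synthetic data with deliberately tiny stand-ins; `TinyEncoder`, `TinyDecoder` and `Wrapper` are hypothetical, as is the choice of nn.CrossEntropyLoss for the criterion. It confirms that the optimizer holds and updates every parameter of both parts.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class TinyEncoder(nn.Module):  # hypothetical stand-in for Encoder
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, 3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))

class TinyDecoder(nn.Module):  # hypothetical stand-in for Decoder
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 2, 1)  # 2 output classes

    def forward(self, h):
        return self.conv(h)

class Wrapper(nn.Module):  # plays the role of UNet_2
    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc  # assigning an nn.Module attribute registers it as a submodule
        self.dec = dec

    def forward(self, x):
        return self.dec(self.enc(x))

torch.manual_seed(0)
unet = Wrapper(TinyEncoder(), TinyDecoder())

# Synthetic data: 8 single-channel 16x16 images with per-pixel labels 0/1.
xs = torch.randn(8, 1, 16, 16)
ys = (xs > 0).long().squeeze(1)  # shape (N, H, W), as CrossEntropyLoss expects
dataloader = DataLoader(TensorDataset(xs, ys), batch_size=4, shuffle=True)

criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(unet.parameters(), lr=1e-4)
n_opt_params = len(opt.param_groups[0]["params"])  # tensors the optimizer will update

before = {k: v.clone() for k, v in unet.state_dict().items()}
for x, y in dataloader:
    pred = unet(x)
    loss = criterion(pred, y)
    opt.zero_grad()
    loss.backward()
    opt.step()

# True for every tensor whose values were changed by the optimizer.
changed = {k: not torch.equal(before[k], v) for k, v in unet.state_dict().items()}
print(n_opt_params, changed)
```

Registration only happens for direct nn.Module attributes or containers such as nn.ModuleList; modules kept in a plain Python list would not appear in parameters().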
Judging by the results, the second method does not work properly: during training, the training loss decreases for the first few epochs but then increases again, and the final validation Dice score is 0. I don't know why. Could it be related to net.parameters()? Any ideas would be appreciated.
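A validation Dice of exactly 0 means there is no overlap at all between predicted and true foreground, which is what happens, for instance, when the network ends up predicting background everywhere. The sketch below assumes a 2-class output with class 1 as foreground and uses a hypothetical `dice_score` helper; it shows that an all-background prediction scores about 0 while a perfect one scores 1, which can help distinguish a collapsed network from a bug in the metric itself.

```python
import torch

def dice_score(logits, target, eps=1e-6):
    """Foreground Dice for a 2-class segmentation (hypothetical helper).

    logits: (N, 2, H, W) raw network outputs; target: (N, H, W) with values 0/1.
    """
    pred_fg = (logits.argmax(dim=1) == 1).float()  # hard per-pixel prediction
    tgt_fg = (target == 1).float()
    inter = (pred_fg * tgt_fg).sum()
    return (2 * inter + eps) / (pred_fg.sum() + tgt_fg.sum() + eps)

target = torch.zeros(1, 4, 4, dtype=torch.long)
target[0, 1:3, 1:3] = 1  # 4 foreground pixels

# A collapsed network: class 0 (background) always scores higher.
all_bg = torch.zeros(1, 2, 4, 4)
all_bg[:, 0] = 1.0
# A perfect prediction: class 1 scores higher exactly on the foreground pixels.
perfect = torch.stack([1 - target.float(), target.float()], dim=1)

d_bg = dice_score(all_bg, target).item()
d_perfect = dice_score(perfect, target).item()
print(d_bg, d_perfect)
```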