I met the error when run the code below. (The code actually does nothing but to reproduce the error.) The modules used when defining model come from github.
If the input of discriminator is detached, the error disappears. The error may result from update the parameter of encoder in the same iteration.
Can I update the encoder sereral times? If not, should I freeze the parameters of UNet when optimize discriminator?
# define model
class discriminator(nn.Module):
def __init__(self, in_ch=1024) -> None:
super().__init__()
self.inc = (DoubleConv(in_ch, 64))
self.e1 = Down(64, 128)
self.out = nn.Conv2d(128, 1, 3, 2, 1)
def forward(self, x):
x = self.inc(x)
x = self.e1(x)
x = self.out(x)
return x
class UNet(nn.Module):
def __init__(self, n_channels, n_classes=2, bilinear=False):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = (DoubleConv(n_channels, 64))
self.down1 = (Down(64, 128))
self.down2 = (Down(128, 256))
self.down3 = (Down(256, 512))
factor = 2 if bilinear else 1
self.down4 = (Down(512, 1024 // factor))
self.up1 = (Up(1024, 512 // factor, bilinear))
self.up2 = (Up(512, 256 // factor, bilinear))
self.up3 = (Up(256, 128 // factor, bilinear))
self.up4 = (Up(128, 64, bilinear))
self.outc = (OutConv(64, n_classes))
self.encoder = nn.Sequential(self.inc, self.down1, self.down2, self.down3, self.down4)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return {'logits': logits, 'mid': x5}
# run
def main():
model = UNet(1).cuda()
D1 = discriminator().cuda()
D2 = discriminator().cuda()
optM = optim.Adam(model.parameters(), lr=1e-3)
optD1 = optim.Adam(list(D1.parameters()) + list(model.encoder.parameters()), lr=1e-3)
optD2 = optim.Adam(list(D2.parameters()) + list(model.encoder.parameters()), lr=1e-3)
for i in range(10):
print("epoch: ", i)
image = torch.randn(4, 1, 256, 256).cuda()
out = model(image)
f = out['mid']
# detached "f" works.
# d1 = D1(f.detach() * 0.99)
# d2 = D2(f.detach() * 1.01)
d1 = D1(f * 0.99)
d2 = D2(f * 1.01)
pred = nn.Softmax(1)(out['logits'])
loss = -torch.log(pred).mean()
ld = torch.abs(d1 - d2).mean()
optM.zero_grad()
optD1.zero_grad()
optD2.zero_grad()
loss.backward(retain_graph=True)
optM.step()
ld.backward(retain_graph=True) # error here
optD1.step()
optD2.step()
return