Thank you for your reply:)
In fact, I think the same error would appear in the CycleGAN code, since my training structure is similar to it.
(I am now updating the official CycleGAN code from v0.3.1 to v0.4 to test this.)
Here are the optimizers:
optimizer_g = torch.optim.Adam(itertools.chain(netGX.parameters(), netGY.parameters()), lr=g_lr, betas=[beta1, beta2])
optimizer_dx = torch.optim.Adam(netDX.parameters(), lr=d_lr, betas=[beta1, beta2])
optimizer_dy = torch.optim.Adam(netDY.parameters(), lr=d_lr, betas=[beta1, beta2])
And here is the training part:
for i, (real_X, real_Y) in enumerate(zip(tr_x_loader, tr_y_loader)):
    if (i+1) % log_step == 1:
        iter_time = time.time()
    # input image data
    real_X = real_X.to(device)
    real_Y = real_Y.to(device)
    # freeze discriminators
    switch_gradient([netDX, netDY], mode='off')
    optimizer_g.zero_grad()
    # train generator netGX ---------------------------------------------------------------------
    fake_Y = netGX(real_X)
    DY_fake_decision = netDY(fake_Y)
    GX_loss = MLE_loss(fake_Y, DY_fake_decision)
    G_loss = GX_loss # + GY_loss + CX_loss + CY_loss
    G_loss.backward()
    optimizer_g.step()
    reset_gradient([netGX, netGY])
    # ...
The error comes from this line of code:
G_loss.backward()
And the functions “MLE_loss” and “switch_gradient” are shown as follows:
def MLE_loss(x, pred):
    return torch.mean(torch.sum(torch.abs(x - pred), 1))

def switch_gradient(net_list, mode='off'):
    if mode == 'off': # turn off the gradients
        for net_id, net in enumerate(net_list):
            for p in net.parameters():
                p.requires_grad = False
    else: # turn on the gradients
        for net_id, net in enumerate(net_list):
            for p in net.parameters():
                p.requires_grad = True
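
For reference, the same toggle can be written more compactly. This is only a sketch: the name set_requires_grad and the two small nn.Linear networks are hypothetical stand-ins, not part of the code above. It relies on Tensor.requires_grad_(), which sets the flag in place.

```python
import torch.nn as nn

def set_requires_grad(net_list, requires_grad):
    """Hypothetical compact equivalent of switch_gradient above."""
    for net in net_list:
        for p in net.parameters():
            # In-place toggle of the flag on each parameter
            p.requires_grad_(requires_grad)

# Stand-in networks, used only to show the call
net_a = nn.Linear(4, 2)
net_b = nn.Linear(4, 2)

set_requires_grad([net_a, net_b], False)   # same as mode='off'
frozen = [p.requires_grad for p in net_a.parameters()]

set_requires_grad([net_a, net_b], True)    # same as mode='on'
unfrozen = [p.requires_grad for p in net_b.parameters()]
```

Passing a boolean instead of a mode string avoids the duplicated loops, but the behavior should be the same as switch_gradient.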
However, if I change this line of code:
GX_loss = MLE_loss(fake_Y, DY_fake_decision)
to
GX_loss = MLE_loss(fake_Y, DY_fake_decision.detach())
then the error disappears, but the results are wrong.
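
One possible reason the results are wrong with .detach(): it cuts the path from the loss back through the discriminator, so the generator no longer receives any gradient via the discriminator's decision. Setting requires_grad = False on the discriminator parameters does not do that; it only stops gradients from being stored for those parameters. The sketch below illustrates the difference with small stand-in networks (gen and disc are hypothetical nn.Linear layers, not the real netGX and netDY), so it does not reproduce the original error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for netGX and netDY
gen = nn.Linear(4, 4)
disc = nn.Linear(4, 4)

# Freeze the discriminator, as switch_gradient(..., mode='off') does
for p in disc.parameters():
    p.requires_grad = False

def mle_loss(x, pred):
    return torch.mean(torch.sum(torch.abs(x - pred), 1))

real = torch.randn(8, 4)

# Case 1: no detach. The gradient reaches gen through fake AND through disc(fake).
fake = gen(real)
loss = mle_loss(fake, disc(fake))
gen.zero_grad()
loss.backward()
grad_full = gen.weight.grad.clone()

# Case 2: detach. disc(fake) is treated as a constant, so only the direct path remains.
fake = gen(real)
loss = mle_loss(fake, disc(fake).detach())
gen.zero_grad()
loss.backward()
grad_detached = gen.weight.grad.clone()

# The frozen discriminator accumulates no gradients in either case
disc_has_no_grad = all(p.grad is None for p in disc.parameters())
# The generator's gradient is different once the discriminator path is cut
grads_differ = not torch.allclose(grad_full, grad_detached)
```

In this toy setup the generator gets a gradient in both cases, but not the same one, which would be consistent with training running without an error while optimizing the wrong objective.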