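Ah, I also just realized that when I wrote the comment above, I hadn't noticed that you have three losses to backpropagate in total. You'll need to make sure all the optimizer updates (the `step()` calls) are done at the very end: `step()` modifies the parameters in place, which breaks the autograd graph that the later `backward()` calls still need.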
So, modifying the code in the following way is probably the preferred solution here:
# One VAE-GAN training step with three losses sharing a single autograd graph.
#
# Why `inputs=` on every backward(): a plain backward() accumulates gradients
# into *every* parameter it reaches, and zero_grad() on one optimizer only
# clears that optimizer's parameters.
#   - loss_gan reaches the encoder/decoder through x_hat and x_p.
#   - loss_dec and loss_enc reach the discriminator through net_D.
# Without the restriction, the discriminator's gradient becomes
#   grad(loss_gan) + grad(gamma * rec_loss - loss_gan) + grad(rec_loss)
# and the GAN term cancels out. The decoder would also pick up grad(loss_enc).
# Restricting each backward() to the parameters of the optimizer that owns
# that loss keeps the three gradients separate.
# Requires PyTorch >= 1.8 for `backward(inputs=...)`.
#
# All step() calls stay at the very end: stepping modifies parameters in
# place and would invalidate the graph the later backward() calls still need.


def _trainable_params(optimizer):
    """Return the parameters managed by *optimizer* that require gradients."""
    return [
        p
        for group in optimizer.param_groups
        for p in group["params"]
        if p.requires_grad
    ]


params_D = _trainable_params(optimizer_D)  # discriminator
params_d = _trainable_params(optimizer_d)  # decoder
params_e = _trainable_params(optimizer_e)  # encoder

# +----------------------------------+
# | discriminator loss |
# +----------------------------------+
d = net_D(x)
d_hat = net_D(x_hat)
d_p = net_D(x_p)
# Constant targets (no gradient). Built from `d` so they share its dtype and
# device, as adversarial_loss requires; replaces the deprecated
# Variable(Tensor(...).fill_(...)).
real_label = d.new_ones((x.size(0), 1))
fake_label = d.new_zeros((x.size(0), 1))
loss_D_real = adversarial_loss(d, real_label)
loss_D_fake = adversarial_loss(d_hat, fake_label)
loss_D_prior = adversarial_loss(d_p, fake_label)
loss_gan = loss_D_real + loss_D_fake + loss_D_prior
optimizer_D.zero_grad()
# retain_graph: loss_dec below is built on top of loss_gan's graph.
loss_gan.backward(retain_graph=True, inputs=params_D)
# +----------------------------+
# | decoder loss |
# +----------------------------+
# NOTE(review): this re-runs net_D on x_hat and x. `d_hat` and `d` above could
# likely be reused unless net_D is stochastic (dropout) or its BatchNorm
# statistics should be updated twice -- confirm.
rec_loss = ((net_D(x_hat) - net_D(x)) ** 2).mean()
print(rec_loss)
loss_dec = gamma * rec_loss - loss_gan
optimizer_d.zero_grad()
# retain_graph: loss_enc below reuses rec_loss's graph.
loss_dec.backward(retain_graph=True, inputs=params_d)
# +----------------------------+
# | encoder loss |
# +----------------------------+
loss_enc = elbo_loss + rec_loss
optimizer_e.zero_grad()
# Last backward over this graph, so it may be freed.
loss_enc.backward(inputs=params_e)
# Update all three networks only after every gradient has been computed.
optimizer_D.step()
optimizer_d.step()
optimizer_e.step()