Based on your code snippet, you might be running into an issue caused by stale forward activations.
In particular, the computation graph of unet_loss is attached to that of dis_loss. Calling dis_loss.backward() creates gradients for the discriminator's parameters, which are then updated in-place via dis_optimizer.step(). Calling unet_loss.backward() afterwards will fail, since the forward activations (and the parameters saved for the backward pass) are now stale, as the corresponding parameters were already updated.
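Here is a minimal sketch of what might be happening; the module and optimizer names are hypothetical stand-ins for your UNet/discriminator setup:

```python
# Minimal sketch reproducing the stale-activation error (names are hypothetical).
import torch
import torch.nn as nn

unet = nn.Linear(4, 4)   # stands in for the UNet / generator
dis = nn.Linear(4, 1)    # stands in for the discriminator

unet_optimizer = torch.optim.SGD(unet.parameters(), lr=0.1)
dis_optimizer = torch.optim.SGD(dis.parameters(), lr=0.1)

x = torch.randn(2, 4)
fake = unet(x)           # forward activations created here
dis_out = dis(fake)      # discriminator output shares the same graph

dis_loss = dis_out.mean()
dis_loss.backward(retain_graph=True)
dis_optimizer.step()     # updates the discriminator parameters in-place

unet_loss = dis_out.mean()   # depends on the now-stale graph
unet_loss.backward()         # RuntimeError: one of the variables needed for
                             # gradient computation has been modified by an
                             # inplace operation
```

A common way out is to call unet_loss.backward() before dis_optimizer.step() updates the parameters, or to recompute the forward pass for the second loss.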