Based on your code snippet you might be running into an issue caused by stale forward activations. In particular, `unet_loss` is attached to the same computation graph as `dis_loss`. Calling `dis_loss.backward()` will compute gradients for a set of parameters, which are then updated in-place via `dis_optimizer.step()`. Calling `unet_loss.backward()` afterwards will fail, since the forward activations are now stale: the parameters they were computed with have already been updated.
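Here is a minimal sketch of that failure pattern, using toy `nn.Linear` stand-ins for your actual models (the names and losses are just placeholders for illustration):

```python
import torch
import torch.nn as nn

# hypothetical stand-ins for the real models
unet = nn.Linear(8, 8)   # "generator"
dis = nn.Linear(8, 1)    # "discriminator"

unet_optimizer = torch.optim.SGD(unet.parameters(), lr=0.1)
dis_optimizer = torch.optim.SGD(dis.parameters(), lr=0.1)

x = torch.randn(4, 8)
fake = unet(x)           # forward through the generator
dis_out = dis(fake)      # forward through the discriminator

dis_loss = dis_out.mean()    # dummy discriminator loss
unet_loss = -dis_out.mean()  # dummy generator loss, shares the same graph

dis_loss.backward(retain_graph=True)  # computes grads for dis (and unet) params
dis_optimizer.step()                  # updates dis parameters in-place

# Fails with an in-place modification error: the discriminator's parameters
# were saved for the backward pass of unet_loss, but step() already changed them.
unet_loss.backward()
```

A common way around this is to call all `backward()` passes before any `optimizer.step()`, or to recompute the forward pass for `unet_loss` after the discriminator update.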