Networks G and E are both used to calculate loss_A and loss_B.
I want to minimize loss_A by updating only G.parameters(), and minimize loss_B by updating only E.parameters().
The following works in serial mode and also with DataParallel.
# Freeze E, so loss_A.backward() only populates G's gradients.
for param in E.parameters():
    param.requires_grad = False
loss_A = calculate_loss_A_using_G_and_E()
loss_A.backward(retain_graph=True)

# Unfreeze E and freeze G, so loss_B.backward() only populates E's gradients.
for param in E.parameters():
    param.requires_grad = True
for param in G.parameters():
    param.requires_grad = False
loss_B = calculate_loss_B_using_G_and_E()
loss_B.backward()

# Restore G for the next iteration.
for param in G.parameters():
    param.requires_grad = True
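The freeze/unfreeze pattern can be sanity-checked in plain, non-distributed PyTorch. The linear layers and squared-error losses below are placeholders for the real G, E, and loss functions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
G = nn.Linear(4, 4)   # placeholder for the real G
E = nn.Linear(4, 4)   # placeholder for the real E
x = torch.randn(2, 4)

# Phase 1: freeze E, so only G accumulates gradients from loss_A.
for param in E.parameters():
    param.requires_grad = False
loss_A = E(G(x)).pow(2).mean()        # stand-in for calculate_loss_A_using_G_and_E()
loss_A.backward(retain_graph=True)
assert all(p.grad is not None for p in G.parameters())
assert all(p.grad is None for p in E.parameters())

# Phase 2: unfreeze E, freeze G, so only E accumulates gradients from loss_B.
for param in E.parameters():
    param.requires_grad = True
for param in G.parameters():
    param.requires_grad = False
g_grad = G.weight.grad.clone()
loss_B = (E(G(x)) - 1).pow(2).mean()  # stand-in for calculate_loss_B_using_G_and_E()
loss_B.backward()
for param in G.parameters():
    param.requires_grad = True

assert torch.equal(G.weight.grad, g_grad)  # G's gradients untouched by loss_B
assert all(p.grad is not None for p in E.parameters())
```

The assertions confirm that each backward pass populates gradients only for the network that is not frozen, which is exactly what breaks once DDP's reducer is involved.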
However, when using DistributedDataParallel, the code above raises RuntimeError: Expected to mark a variable ready only once. I suspect this happens because DDP registers its gradient hooks on all parameters when the model is wrapped, so on the first call to .backward() the reducer still marks E.parameters() as ready even though their requires_grad is False.
What can be done to use DistributedDataParallel in this scenario?
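One possible direction (my own sketch, not a confirmed fix): never toggle requires_grad after the DDP wrappers are built. Instead, run the "frozen" network with its parameters detached via torch.func.functional_call, and do a single backward per iteration, so each parameter enters the autograd graph exactly once. Here G, E, and the losses are again placeholders; the real calculate_loss_* functions would need the same detaching applied inside them:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
G = nn.Linear(4, 4)   # placeholder; would be wrapped in DistributedDataParallel
E = nn.Linear(4, 4)   # placeholder; would be wrapped in DistributedDataParallel
x = torch.randn(2, 4)

def frozen(module, inp):
    # Forward pass with the module's parameters detached: gradients still
    # flow back to `inp`, but the module's own parameters never enter the
    # graph, so DDP would never mark them ready for this loss.
    params = {name: p.detach() for name, p in module.named_parameters()}
    return functional_call(module, params, (inp,))

# loss_A trains G only: E's parameters are kept out of the graph.
loss_A = frozen(E, G(x)).pow(2).mean()
# loss_B trains E only: G's parameters are kept out of the graph.
loss_B = (E(frozen(G, x)) - 1).pow(2).mean()

# One backward per iteration, so every parameter is marked ready once.
(loss_A + loss_B).backward()
assert all(p.grad is not None for p in G.parameters())  # from loss_A only
assert all(p.grad is not None for p in E.parameters())  # from loss_B only
```

Whether this plays well with DDP's reducer in your exact setup is an open question (under DDP you would apply functional_call to the underlying .module), but the two key properties it aims for are: requires_grad is never changed after DDP construction, and backward() runs once per iteration.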