DistributedDataParallel with multiple losses, freezing some parameters at a time

Networks G and E are both used to calculate loss_A and loss_B.

I want to minimize loss_A by updating only G.parameters(), and minimize loss_B by updating only E.parameters().

The following works in a single process (no parallelism) and also with DataParallel:

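# Phase 1: freeze E so that loss_A only produces gradients for G
for param in E.parameters():
    param.requires_grad = False

loss_A = calculate_loss_A_using_G_and_E()
loss_A.backward(retain_graph=True)

# Phase 2: unfreeze E and freeze G so that loss_B only produces gradients for E
for param in E.parameters():
    param.requires_grad = True
for param in G.parameters():
    param.requires_grad = False

loss_B = calculate_loss_B_using_G_and_E()
loss_B.backward()

# Restore G so that both networks are trainable again
for param in G.parameters():
    param.requires_grad = True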
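For reference, the freezing pattern above can be checked in isolation with a minimal single-process sketch, assuming PyTorch is installed. G, E, the input x and the two loss functions are hypothetical toy stand-ins, not the networks from the question, and DistributedDataParallel is not involved. The sketch compares the gradients left in .grad against d(loss_A)/dG and d(loss_B)/dE computed independently with torch.autograd.grad. Because each loss re-runs the forward pass here, retain_graph=True is not needed in this sketch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-ins for the networks G and E.
G = nn.Linear(4, 3)
E = nn.Linear(3, 1)
x = torch.randn(8, 4)

def calculate_loss_A_using_G_and_E():
    # Hypothetical loss_A: pushes E(G(x)) toward 0.
    return E(G(x)).pow(2).mean()

def calculate_loss_B_using_G_and_E():
    # Hypothetical loss_B: pushes E(G(x)) toward 1.
    return (E(G(x)) - 1).pow(2).mean()

def set_requires_grad(module, flag):
    for p in module.parameters():
        p.requires_grad_(flag)

# Reference gradients computed without freezing; autograd.grad does not touch .grad.
ref_A = torch.autograd.grad(calculate_loss_A_using_G_and_E(), list(G.parameters()))
ref_B = torch.autograd.grad(calculate_loss_B_using_G_and_E(), list(E.parameters()))

# Phase 1: freeze E; gradients still flow through E's activations back into G.
set_requires_grad(E, False)
loss_A = calculate_loss_A_using_G_and_E()
loss_A.backward()
set_requires_grad(E, True)
e_untouched_by_A = all(p.grad is None for p in E.parameters())
g_grads_after_A = [p.grad.clone() for p in G.parameters()]

# Phase 2: freeze G; G is not part of this graph, so its .grad is left alone.
set_requires_grad(G, False)
loss_B = calculate_loss_B_using_G_and_E()
loss_B.backward()
set_requires_grad(G, True)

print(all(torch.allclose(p.grad, r) for p, r in zip(G.parameters(), ref_A)))
print(all(torch.allclose(p.grad, r) for p, r in zip(E.parameters(), ref_B)))
```

In this single-process run, each parameter receives gradients from exactly one of the two backward calls, which is the behavior the question relies on.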

However, with DistributedDataParallel the code above raises RuntimeError: Expected to mark a variable ready only once. I suspect this happens because, during the first call to .backward(), E's parameters are marked as ready even though their requires_grad is False.

What can be done to use DistributedDataParallel in this scenario?