How do I reimplement the following logic in torch 1.7+?

        # loss1, loss2, loss3 all depend on both modelA and modelB
        # first we want to use loss1, loss2 to update modelA
        # then use loss2, loss3 to update modelB
        modelB.zero_grad()
        modelA.zero_grad()
        lossA = loss1 + loss2
        lossA.backward(retain_graph=True)
        modelA_optimizer.step()

        modelB.zero_grad()
        lossB = loss2 + loss3
        lossB.backward()
        modelB_optimizer.step()

Apparently the code above works in torch 1.0.0 but fails in torch 1.7.0.

The logic is flawed, and PyTorch now properly raises an error: modelA_optimizer.step() modifies modelA's parameters in-place, so the second backward call would be using stale forward activations saved in the retained graph, as described here.
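
For context, here is a minimal standalone reproduction of that failure mode (toy models, SGD, random data, and simplified losses standing in for the real setup):

    import torch
    import torch.nn as nn

    # toy stand-ins for the real models and optimizers
    modelA = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
    modelB = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
    modelA_optimizer = torch.optim.SGD(modelA.parameters(), lr=0.1)
    modelB_optimizer = torch.optim.SGD(modelB.parameters(), lr=0.1)

    x = torch.randn(2, 4)
    outA = modelA(x)
    outB = modelB(outA)
    loss1, loss2, loss3 = outA.mean(), outB.mean(), outB.pow(2).mean()

    modelB.zero_grad()
    modelA.zero_grad()
    lossA = loss1 + loss2
    lossA.backward(retain_graph=True)
    modelA_optimizer.step()  # modifies modelA's parameters in-place

    modelB.zero_grad()
    lossB = loss2 + loss3
    # In torch 1.7 this should raise a RuntimeError along the lines of
    # "one of the variables needed for gradient computation has been modified
    # by an inplace operation": modelA's weights are saved in the retained
    # graph but were just changed by modelA_optimizer.step(). torch 1.0
    # silently backpropagated through the stale values instead.
    lossB.backward()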

Sorry for the late reply. I wonder if something like this would work:

gradA, gradB = {}, {}

# backward through lossA first, keeping the graph alive for lossB
lossA.backward(retain_graph=True)

# cache the gradients that lossA produced on both models
for name, param in modelA.named_parameters():
    gradA[name] = param.grad.detach().clone()
for name, param in modelB.named_parameters():
    gradB[name] = param.grad.detach().clone()

lossB.backward()

# restore modelA's grad so it only contains lossA's contribution
for name, param in modelA.named_parameters():
    param.grad = gradA[name]
# subtract lossA's contribution so modelB's grad only contains lossB's
for name, param in modelB.named_parameters():
    param.grad -= gradB[name]

# update both models afterwards

I don’t think it would help, as the intermediate forward activations would still be stale relative to the already-updated parameters, as explained in the linked post.
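
If the goal is for modelB's update to take the already-updated modelA into account, one straightforward workaround is to redo the forward pass after the first optimizer step so the activations are fresh. A rough sketch, reusing the toy setup from the reproduction above (compute_losses is a placeholder for however loss1, loss2, loss3 are actually produced):

    def compute_losses(modelA, modelB, x):
        # placeholder forward pass; the real losses come from the actual models
        outA = modelA(x)
        outB = modelB(outA)
        return outA.mean(), outB.mean(), outB.pow(2).mean()

    loss1, loss2, loss3 = compute_losses(modelA, modelB, x)

    modelA.zero_grad()
    (loss1 + loss2).backward()   # no retain_graph: the graph is rebuilt below
    modelA_optimizer.step()

    # fresh forward pass, so the saved tensors match the updated modelA
    loss1, loss2, loss3 = compute_losses(modelA, modelB, x)

    modelB.zero_grad()
    (loss2 + loss3).backward()   # modelA also accumulates grads here, but they
    modelB_optimizer.step()      # are cleared by zero_grad() on the next iteration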