How to temporarily detach parameters when using DataParallel?

I need to do something like this:

import torch

class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, net1, net2, x):
        ctx.net1 = net1
        ctx.net2 = net2
        ctx.save_for_backward(x)
        return net1(x)

    @staticmethod
    def backward(ctx, grad):
        net1 = ctx.net1
        net2 = ctx.net2
        x, = ctx.saved_tensors  # save_for_backward always hands back a tuple
        # Disable gradients for net2's parameters, because I only need the
        # gradient of net2's output with respect to x.
        for params in net2.parameters():
            params.requires_grad_(False)
        with torch.enable_grad():
            y = net2(x)
        y.backward(torch.ones_like(x).to(x))
        gradx = x.grad.clone().detach()
        # Re-enable gradients for net2, because it is used in other computations.
        for params in net2.parameters():
            params.requires_grad_(True)
        return (None, None, gradx)

This code works well on a single GPU. However, when I use DataParallel with multiple GPUs, the gradient is wrong.

I guess maybe it is because there is no lock across the replicas, so some gradients are still propagated back into the parameters of net2. How can I fix my code for DataParallel models?

My guess for why it doesn’t work is that you can no longer get the parameters on a DataParallel replica.
One workaround (again, a guess) is to use torch.autograd.grad instead of backward.
You can do:

gradx = torch.autograd.grad(y, x, torch.ones_like(x).to(x))[0]
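
In case it is useful, here is a minimal sketch of how MyOp.backward could look with that change (untested; it keeps the original assumption that net2(x) has the same shape as x, and it re-attaches x locally so autograd.grad can reach it):

    @staticmethod
    def backward(ctx, grad):
        net2 = ctx.net2
        x, = ctx.saved_tensors
        # Work on a local leaf so the gradient with respect to x can be taken here.
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            y = net2(x)
            # autograd.grad returns dy/dx directly and never writes into the
            # .grad fields of net2's parameters, so no requires_grad_ toggling
            # is needed and there is nothing for the replicas to race on.
            gradx = torch.autograd.grad(y, x, torch.ones_like(x).to(x))[0]
        return (None, None, gradx)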

Can you set requires_grad to True for net2’s parameters before you start the forward pass and then set them to False once the forward is done? That way the flag is already False for the entire backward pass, despite the concurrent execution of the backward across multiple GPUs.
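
If I understand the timing correctly, that means moving the toggling out of the Function’s backward and into the surrounding training code. Below is a condensed, single-process sketch of that ordering (the class and module names are illustrative, not from the code above; under DataParallel the point is simply that the flag is already False before any replica starts its backward):

import torch
import torch.nn as nn

class FrozenOp(torch.autograd.Function):
    # Same structure as MyOp above, but without any requires_grad_ calls
    # inside backward; the flag is managed by the caller instead.
    @staticmethod
    def forward(ctx, net1, net2, x):
        ctx.net2 = net2
        ctx.save_for_backward(x)
        return net1(x)

    @staticmethod
    def backward(ctx, grad):
        net2 = ctx.net2
        x, = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            y = net2(x)
        # net2's parameters were frozen before loss.backward() was called,
        # so this inner backward leaves their .grad fields untouched.
        y.backward(torch.ones_like(y))
        return (None, None, x.grad)

net1, net2 = nn.Linear(4, 4), nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

for p in net2.parameters():
    p.requires_grad_(True)       # enabled while the forward graph is built
loss = FrozenOp.apply(net1, net2, x).sum()

for p in net2.parameters():
    p.requires_grad_(False)      # frozen once, before backward starts
loss.backward()

print(x.grad.shape)              # gradient w.r.t. x is produced
print(net2.weight.grad)          # None: net2 received no gradients here

In a real training loop you would flip the flags back to True before the next forward, as suggested above.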

Thanks! I’ve tried your suggestion and it works!
