I need to do something like this:
class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, net1, net2, x):
        ctx.net1 = net1
        ctx.net2 = net2
        ctx.save_for_backward(x)
        return net1(x)

    @staticmethod
    def backward(ctx, grad):
        net1 = ctx.net1
        net2 = ctx.net2
        x, = ctx.saved_tensors
        # Disable gradients for net2's parameters, because I only need the gradient w.r.t. x through net2.
        for params in net2.parameters():
            params.requires_grad_(False)
        with torch.enable_grad():
            y = net2(x)
            y.backward(torch.ones_like(x))
        gradx = x.grad.clone().detach()
        # Re-enable gradients for net2's parameters, because net2 is used in other computations.
        for params in net2.parameters():
            params.requires_grad_(True)
        return (None, None, gradx)
This code works well on a single GPU. However, when I use DataParallel with multiple GPUs, the gradients are wrong.
I guess it may be because there is no lock across the parallel replicas, so some gradients get back-propagated into the parameters of net2. How can I correct my code so that it works with DataParallel models?
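One idea I am considering (not tested with DataParallel, so treat it only as a sketch): compute the input gradient with torch.autograd.grad instead of calling y.backward(). torch.autograd.grad returns the gradient directly rather than accumulating into .grad buffers, so the .grad of net2's parameters is never touched and I would not need to toggle requires_grad at all:

import torch

class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, net1, net2, x):
        ctx.net1 = net1
        ctx.net2 = net2
        ctx.save_for_backward(x)
        return net1(x)

    @staticmethod
    def backward(ctx, grad):
        net2 = ctx.net2
        x, = ctx.saved_tensors
        with torch.enable_grad():
            # Rebuild the graph on a detached leaf copy of x, so this backward
            # pass stays local to the current call/replica.
            x_local = x.detach().requires_grad_(True)
            y = net2(x_local)
            # torch.autograd.grad returns the gradient w.r.t. x_local directly and
            # does not write into the .grad buffers of net2's parameters.
            gradx, = torch.autograd.grad(y, x_local, grad_outputs=torch.ones_like(y))
        return None, None, gradx

Would that be a correct way to make the backward safe under DataParallel, or is there a better pattern?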