Custom module with partial autograd

Consider the following extension of torch.autograd.Function:

class MorphingLayer(Function):

    @staticmethod
    def forward(ctx, input, idx):
        ctx.save_for_backward(input, idx)
        # implementation of the forward pass
        return output1, output2

Assume that the gradient w.r.t. input can be obtained using autograd, but the gradient w.r.t. idx must be implemented manually. How would I achieve that in the implementation of the backward pass, i.e., how can I access input.grad during the backward pass?

@staticmethod
def backward(ctx, gradOutput1, gradOutput2):
    input, idx = ctx.saved_tensors

    # compute gradIdx manually
    gradIdx = (manually computed gradient)

    # obtain gradInput from autograd
    gradInput = (obtain input.grad)

    return gradInput, gradIdx

This should give you a start for the automatically calculated gradient bit:

import torch

class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a_):
        with torch.enable_grad():
            # work on a detached copy, so the graph recorded here stays local to this Function
            a = a_.detach().requires_grad_()
            res = a**2
        ctx.save_for_backward(a, res)
        # return a detached result, so no second path into the surrounding graph is created
        return res.detach()

    @staticmethod
    def backward(ctx, grad_out):
        a, res = ctx.saved_tensors
        # differentiate the locally recorded graph to get the gradient w.r.t. a
        gr, = torch.autograd.grad(res, a, grad_out, retain_graph=True)
        return gr

x = torch.randn(2,2, requires_grad=True, dtype=torch.double)
print(torch.autograd.gradcheck(MyFn.apply, (x,)))

I must admit I have no idea why you’d need to specify “retain_graph”; it’s just the empirical observation that it doesn’t run without it.
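To connect this back to the original MorphingLayer signature (two outputs, gradient w.r.t. idx computed by hand), the same pattern could look roughly like the sketch below. The forward computation and the manual gradient are made-up stand-ins (output1 = input * idx, output2 = input + idx), not your real ones:

import torch
from torch.autograd import Function

class MorphingLayer(Function):
    @staticmethod
    def forward(ctx, input, idx):
        with torch.enable_grad():
            # detached copy, so the recorded graph stays local to this Function
            inp = input.detach().requires_grad_()
            # stand-in forward computation
            output1 = inp * idx
            output2 = inp + idx
        ctx.save_for_backward(inp, idx, output1, output2)
        return output1.detach(), output2.detach()

    @staticmethod
    def backward(ctx, gradOutput1, gradOutput2):
        inp, idx, output1, output2 = ctx.saved_tensors
        # gradient w.r.t. input via autograd on the locally recorded graph
        gradInput, = torch.autograd.grad((output1, output2), inp,
                                         (gradOutput1, gradOutput2),
                                         retain_graph=True)
        # gradient w.r.t. idx computed manually
        # (this formula matches the stand-in forward above)
        gradIdx = gradOutput1 * inp + gradOutput2
        return gradInput, gradIdx

x = torch.randn(3, requires_grad=True, dtype=torch.double)
i = torch.randn(3, requires_grad=True, dtype=torch.double)
print(torch.autograd.gradcheck(MorphingLayer.apply, (x, i)))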

Best regards

Thomas


That’s extremely helpful, thanks! Can someone explain what’s going on under the hood (why detach? why retain_graph?) and whether it’s safe to combine autograd and a custom backward pass in this way? Does it break other functionality?

@tom, I think that is really a nice approach. I’d like to explain it a bit:

In PyTorch, autograd usually computes the gradients of all operations automatically, as long as requires_grad is set to True. If, however, you need operations that are not natively supported by PyTorch’s autograd, you can define the function and the computation of its gradients manually. For that reason, autograd is turned off by default in forward and backward of subclasses of torch.autograd.Function. To turn it back on manually, tom used

with torch.enable_grad():

Within this block, operations are recorded so that gradients are 1) automatically calculated and 2) passed backwards through the graph. You want 1), but I think 2) is not a good idea within forward, because you are expected to pass the gradient backwards explicitly in backward.
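Just to illustrate that point: grad mode really is switched off while forward runs, and torch.enable_grad() switches it back on locally. A tiny probe (GradModeProbe is just an illustrative name) shows it:

import torch
from torch.autograd import Function

class GradModeProbe(Function):
    @staticmethod
    def forward(ctx, x):
        # grad mode is disabled while forward of a custom Function runs
        print("inside forward:", torch.is_grad_enabled())          # False
        with torch.enable_grad():
            print("inside enable_grad:", torch.is_grad_enabled())  # True
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

x = torch.randn(2, requires_grad=True)
print("outside apply:", torch.is_grad_enabled())                    # True
GradModeProbe.apply(x)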

To prevent the gradient from flowing backward automatically, you need to detach the input from the graph (a_.detach()). I guess from then on you could let the gradient be calculated implicitly by returning res (instead of res.detach()) and read it from gr = a.grad (perhaps you would have to use retain_graph=True when calling backward), but the more explicit way is to also detach the result and calculate the gradient explicitly with

gr, = torch.autograd.grad(res, a, grad_out, retain_graph=True)

Concerning retain_graph I am a bit puzzled, too.

Apart from that, I think this is a good approach. I cannot think of any big impact on other aspects, so it should be safe. The only thing I can think of is that it forces the calculation of gradients (within this function) even when running the defined function in a

with torch.no_grad():

block. If computing the gradients of your function is complex or requires storing lots of intermediate results, that might cause GPU memory or runtime issues.

One thing you could do is check whether the input (a_) requires gradients and, if it does not, use torch.no_grad() instead of torch.enable_grad() or skip the requires_grad_() call, to avoid calculating unnecessary gradients.
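A sketch of that check, written into the MyFn example from above (only the requires_grad part; an outer torch.no_grad() cannot be detected from inside forward, because grad mode is off there anyway):

import torch

class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a_):
        if a_.requires_grad:
            # build the local graph only if a gradient will actually be asked for
            with torch.enable_grad():
                a = a_.detach().requires_grad_()
                res = a**2
            ctx.save_for_backward(a, res)
            return res.detach()
        # nothing requires a gradient: plain computation, nothing saved
        return a_**2

    @staticmethod
    def backward(ctx, grad_out):
        a, res = ctx.saved_tensors
        gr, = torch.autograd.grad(res, a, grad_out, retain_graph=True)
        return gr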


Nice explanation! @Florian_1990

It might be easiest to have a wrapper (you wouldn’t use MyFn.apply as a name directly, anyway) that checks whether torch.is_grad_enabled() and whether the inputs need the gradient, and passes that as a boolean argument to the function. At least that’s what I usually do for extensions.
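A minimal sketch of such a wrapper, reusing the MyFn example from above (the wrapper name my_fn and the boolean argument compute_grad are just placeholders):

import torch

class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a_, compute_grad):
        if compute_grad:
            with torch.enable_grad():
                a = a_.detach().requires_grad_()
                res = a**2
            ctx.save_for_backward(a, res)
            return res.detach()
        return a_**2

    @staticmethod
    def backward(ctx, grad_out):
        a, res = ctx.saved_tensors
        gr, = torch.autograd.grad(res, a, grad_out, retain_graph=True)
        return gr, None  # no gradient for the boolean flag

def my_fn(a):
    # decide outside the Function whether a gradient will be needed;
    # unlike inside forward, grad mode is still visible here
    compute_grad = torch.is_grad_enabled() and a.requires_grad
    return MyFn.apply(a, compute_grad)

x = torch.randn(2, 2, requires_grad=True, dtype=torch.double)
with torch.no_grad():
    my_fn(x)                   # skips the local graph entirely
my_fn(x).sum().backward()      # records it and differentiates through it
print(x.grad)                  # 2 * x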

Best regards

Thomas