Using AD inside a custom defined torch.autograd.Function

Hi – so I’m defining a torch.autograd.Function.

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2):
        # some computation
        ctx.save_for_backward(....)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # some computation
        return grad1, grad2

From what I understand, anything that happens to input1 inside forward is not tracked, and the gradient applied is whatever is returned from backward. (For example, if loss = input1 * 2 but grad1 = grad_output * 10, the computation graph won’t care about the first statement and will treat the gradient as grad_output * 10.) Is my understanding correct? Does torch treat the tensor differently because the operation happens inside a torch.autograd.Function rather than a module?
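
To make the question concrete, here is a minimal, self-contained version of what I mean (the class name and the numbers are just made up for illustration):

import torch

class TimesTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1):
        # forward really multiplies by 2...
        return input1 * 2

    @staticmethod
    def backward(ctx, grad_output):
        # ...but backward claims the gradient is 10
        return grad_output * 10

x = torch.ones(3, requires_grad=True)
TimesTwo.apply(x).sum().backward()
print(x.grad)  # tensor([10., 10., 10.]) -- the "* 2" in forward is ignored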

Assuming this understanding is correct, is there a way to combine both behaviors? In particular, suppose my forward function looks something like this:

# inside class MyFunction(torch.autograd.Function)
@staticmethod
def forward(ctx, input1, input2):
    a = some_computation_a(input1)
    b = some_computation_b(input1)
    return a + b

All of the steps of some_computation_a aren’t directly differentiable so I want to supply the gradient to those operations myself, but all of the operations in some_computation_b are, so I’d like torch to track those operations automatically. Is this possible to do?

Thanks,
James


Hi,

Is my understanding correct?

Yes you are absolutely correct!

Does torch treat the tensor differently because the operation happens inside a torch.autograd.Function rather than a module?

No, the Tensors are treated the same way, but everything runs with autograd disabled during the forward.
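
You can see this by printing from inside the forward. A quick sketch (the class here is only for demonstration):

import torch

class Check(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1):
        # grad mode is off here, so nothing below is recorded by autograd
        print(torch.is_grad_enabled())  # False
        print((input1 * 2).grad_fn)     # None
        return input1 * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

x = torch.ones(3, requires_grad=True)
out = Check.apply(x)
print(out.grad_fn)  # the custom Function's backward node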

Is this possible to do?

Of course!
I think the simplest solution (and it is also a general guideline) is to keep your custom Function as small as possible: it will make the gradients easier to derive.
In particular, you can have a custom Function that handles the forward/backward for some_computation_a only, and then combine it with some_computation_b in a regular Python function:

def my_fn(input1, input2):
    return SomeComputationA.apply(input1) + some_computation_b(input2)
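
For example, a rough sketch of the whole pattern (the actual computations are placeholders I made up, yours will differ):

import torch

class SomeComputationA(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1):
        # non-differentiable step, e.g. rounding
        return input1.round()

    @staticmethod
    def backward(ctx, grad_output):
        # hand-written gradient, e.g. a straight-through estimator
        return grad_output

def some_computation_b(input2):
    # plain differentiable ops: autograd tracks these on its own
    return input2 ** 2

def my_fn(input1, input2):
    return SomeComputationA.apply(input1) + some_computation_b(input2)

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
my_fn(x, y).sum().backward()
print(x.grad)  # comes from the custom backward above
print(y.grad)  # comes from autograd: 2 * y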

Does that work for you?


Yes, I think so – but I’ll update this if for some reason it doesn’t work.


Hi albanD,

Can I implement a learnable parameter with torch.autograd.Function? For example, a PReLU activation function implemented with torch.autograd.Function.

Thank you for your time

Hi,

I would recommend using an nn.Module for this, where you can register your learnable parameters as nn.Parameter so that they are properly detected by the other nn constructs.
I would recommend using a custom autograd.Function only if you need to work with something that is non-differentiable or you want custom computations (beyond just gradient) to happen during the backward.
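
For example, a PReLU-like activation can be written as a plain nn.Module (just a sketch, not the builtin nn.PReLU):

import torch
from torch import nn

class MyPReLU(nn.Module):
    def __init__(self, init=0.25):
        super().__init__()
        # nn.Parameter makes the slope show up in .parameters(),
        # the state_dict, optimizers, etc.
        self.weight = nn.Parameter(torch.tensor(init))

    def forward(self, x):
        # max(0, x) + a * min(0, x)
        return torch.clamp(x, min=0) + self.weight * torch.clamp(x, max=0)

m = MyPReLU()
print(list(m.parameters()))  # the learnable slope is detected automatically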
