How can I call the backward function of operations in torch.nn.functional

We know that we can implement our own custom autograd functions by subclassing torch.autograd.Function and implementing the forward and backward passes that operate on Tensors, as shown in https://pytorch.org/docs/stable/autograd.html#function:

import torch
from torch.autograd import Function

class Exp(Function):

    @staticmethod
    def forward(ctx, i):
        # Compute exp(i) and stash the result for the backward pass.
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # d/dx exp(x) = exp(x), so reuse the saved forward result.
        result, = ctx.saved_tensors
        return grad_output * result
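
A quick usage sketch (the example values here are my own, not from the docs):

x = torch.randn(3, requires_grad=True)
y = Exp.apply(x)            # runs Exp.forward
y.sum().backward()          # runs Exp.backward under the hood
print(torch.allclose(x.grad, x.exp()))  # the gradient of exp is exp itself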

The question is: how can I call the backward function of operations in torch.nn.functional? Is this possible, and what is the simplest way?
For example, I want to call the backward function of torch.nn.functional.softmax and get the Jacobian matrix. I have tried torch.autograd.functional.jacobian, but it is experimental for now, and I get a wrong result when I apply it to softmax with a 3D tensor. Is there an existing backward function that I can call directly?

Hi,

The functions in torch.nn.functional are not all elementary operations for autograd, so there isn’t a single backward function you can call for them.
What you can do, though, is get the gradient with the regular autograd:

import torch
import torch.nn.functional as functional

# inp must have requires_grad=True and grad_output must match what you want to compute.
# If you want a full Jacobian, you will need a for loop to reconstruct it line by line (see below).
out = functional.softmax(inp, dim=-1)
grad, = torch.autograd.grad(out, inp, grad_output)
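
If you do need the full Jacobian this way, here is a sketch of that row-by-row loop (the 1-D input and the variable names are my own assumptions):

inp = torch.randn(4, requires_grad=True)
out = functional.softmax(inp, dim=-1)
rows = []
for i in range(out.numel()):
    grad_output = torch.zeros_like(out)
    grad_output[i] = 1.0  # selects the i-th row of the Jacobian
    rows.append(torch.autograd.grad(out, inp, grad_output, retain_graph=True)[0])
jacobian = torch.stack(rows)  # shape (out.numel(), inp.numel())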

What is the issue you have with functional.jacobian? What do you mean by wrong result?


Hi albanD,
Thank you for your answer; torch.autograd.grad indeed works for me!
I checked the docs of functional.jacobian again and found that I had misunderstood how to use it. After changing my code, the result of functional.jacobian is all good now.
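
For anyone who lands here with the same confusion, the corrected call looks roughly like this (the shapes are my own example):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4)
# jacobian treats the whole input as one argument, so the result has
# shape out.shape + x.shape, i.e. (2, 3, 4, 2, 3, 4) here.
J = torch.autograd.functional.jacobian(lambda t: F.softmax(t, dim=-1), x)
print(J.shape)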


Hi, how can I calculate the backward gradient of the softmax operation using PyTorch operations (not autograd)? Or is there a quick operator for it?

No, there is no such operation that you can call directly from Python.
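
That said, the softmax vector-Jacobian product follows directly from the softmax Jacobian (dy_i/dx_j = y_i * (delta_ij - y_j)), so it can be written by hand with ordinary tensor ops. A minimal sketch, assuming a last-dim softmax and tensor names of my own choosing:

import torch
import torch.nn.functional as F

x = torch.randn(2, 5)
y = F.softmax(x, dim=-1)
grad_output = torch.randn_like(y)  # upstream gradient dL/dy

# Vector-Jacobian product of softmax: dL/dx = y * (dL/dy - sum(dL/dy * y))
grad_input = y * (grad_output - (grad_output * y).sum(dim=-1, keepdim=True))

You can sanity-check it against autograd by comparing with torch.autograd.grad(y, x, grad_output) on an x that requires grad.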