How to extend torch.autograd with a Function subclass

Hi, I found that there are two ways in PyTorch to extend torch.autograd by creating subclasses of torch.autograd.Function.

<1> Take this function as an example:

from torch.autograd.function import Function

class FuncA(Function):

    def __init__(self):
        super(FuncA, self).__init__()

    def forward(self, input):
        output = input.new()
        # calculation of output here
        self.save_for_backward(input)
        return output

    def backward(self, grad_output):
        input, = self.saved_tensors  # saved_tensors is a tuple, so unpack it
        grad_input = grad_output.new()
        # calculation of input gradient
        return grad_input

In this example, input, output, grad_output and grad_input are all torch.Tensor objects.

<2> Take this function as an example:

from torch.autograd.function import Function

class FuncB(Function):

    @staticmethod
    def forward(ctx, input):
        output = input.new()
        # calculation of output here
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_variables  # a tuple, so unpack it
        grad_input = FuncBBackward.apply(grad_output)
        return grad_input

In this example, input and output are torch.Tensor, while grad_output and grad_input are torch.autograd.variable.Variable. FuncBBackward is also a subclass of torch.autograd.Function.
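
For reference, here is a minimal runnable sketch of the new style, assuming a recent PyTorch release (since 0.4, Variable and Tensor have been merged, so grad_output arrives as an ordinary torch.Tensor). The Cube class and its x**3 computation are hypothetical, chosen only because the derivative 3*x**2 is easy to check by hand.

```python
import torch
from torch.autograd import Function


class Cube(Function):
    """Hypothetical new-style Function computing y = x**3 elementwise."""

    @staticmethod
    def forward(ctx, input):
        # State for backward lives on ctx, not on an instance (there is no self).
        ctx.save_for_backward(input)
        return input ** 3

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors is a tuple, so unpack it even for a single tensor.
        (input,) = ctx.saved_tensors
        # Chain rule: dL/dx = dL/dy * dy/dx = grad_output * 3 * x**2.
        return grad_output * 3 * input ** 2


x = torch.tensor([1.0, -2.0, 0.5], dtype=torch.double, requires_grad=True)
y = Cube.apply(x)  # new-style Functions are called through .apply
y.sum().backward()
print(x.grad)  # equals 3 * x**2 elementwise

# gradcheck compares the analytic backward with finite differences
# (it expects double-precision inputs) and returns True on success.
print(torch.autograd.gradcheck(Cube.apply, (x,)))
```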

QUESTION

What is the difference between these two ways?
Why is @staticmethod used for the second one but not for the first?
Why does grad_output have a different data type in the two methods?
Is there any difference between

self.save_for_backward(input)
input = self.saved_tensors

and

ctx.save_for_backward(input)
input = ctx.saved_variables

?
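
On the save_for_backward question, a small sketch (assuming a recent PyTorch, where saved_variables is deprecated in favor of saved_tensors) of how ctx is typically used: tensors needed in backward go through save_for_backward and come back as a tuple, while non-tensor data can be stored as a plain attribute on ctx. ScaledMul and its scale argument are hypothetical.

```python
import torch
from torch.autograd import Function


class ScaledMul(Function):
    """Hypothetical example: y = scale * a * b, with a Python float `scale`."""

    @staticmethod
    def forward(ctx, a, b, scale):
        # Tensors needed in backward go through save_for_backward ...
        ctx.save_for_backward(a, b)
        # ... while non-tensor data can simply be stored as an attribute on ctx.
        ctx.scale = scale
        return scale * a * b

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors always returns a tuple, in the order the tensors were saved.
        a, b = ctx.saved_tensors
        grad_a = grad_output * ctx.scale * b
        grad_b = grad_output * ctx.scale * a
        # One return value per forward input; None for the non-tensor `scale`.
        return grad_a, grad_b, None


a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0], requires_grad=True)
ScaledMul.apply(a, b, 2.0).sum().backward()
print(a.grad, b.grad)  # 2 * b and 2 * a
```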

Thanks~

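This is old-style (the first) vs. new-style (the second) autograd. The new style was introduced in PyTorch 0.2 and is required, e.g., for second derivatives. Note that forward and backward are static methods that receive a context object, ctx, to hold the data passed between them, rather than storing it on an instance via self; that is why @staticmethod is used.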
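
To make "required for second derivatives" concrete, a minimal sketch assuming a recent PyTorch: because the new-style backward is written with ordinary differentiable tensor operations, autograd can differentiate through it when create_graph=True. Cube is again a hypothetical example.

```python
import torch
from torch.autograd import Function


class Cube(Function):
    """Hypothetical example: y = x**3, with a backward built from differentiable ops."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Ordinary tensor ops: when create_graph=True autograd records them,
        # so this gradient can itself be differentiated.
        return grad_output * 3 * input ** 2


x = torch.tensor(2.0, dtype=torch.double, requires_grad=True)
y = Cube.apply(x)

# First derivative, keeping the graph so it can be differentiated again.
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)  # 3 * x**2
# Second derivative: autograd differentiates through Cube.backward.
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)  # 6 * x
print(dy_dx.item(), d2y_dx2.item())  # 12.0 12.0

# gradgradcheck numerically verifies the second derivative.
print(torch.autograd.gradgradcheck(Cube.apply, (x,)))
```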

An important difference you did not mention is how the functions are called. With the old style you instantiate the class and call the instance, FuncA()(input); with the new style, use FuncB.apply(input).
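
As a sketch of the new-style calling convention (the MyReLU, my_relu and MyReLUModule names are hypothetical, and a recent PyTorch is assumed), a common pattern is to alias .apply or wrap it in an nn.Module so call sites never touch the Function class directly.

```python
import torch
from torch.autograd import Function


class MyReLU(Function):
    """Hypothetical new-style Function: elementwise ReLU."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0  # no gradient where the input was negative
        return grad_input


# Alias .apply so call sites read like an ordinary function.
my_relu = MyReLU.apply


class MyReLUModule(torch.nn.Module):
    """Wraps the Function so it can be used like any other layer."""

    def forward(self, x):
        return MyReLU.apply(x)


x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
out = MyReLUModule()(x) + my_relu(x)
out.sum().backward()
print(x.grad)  # 2 where x > 0, 0 where x < 0
```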

Best regards

Thomas

@tom Thanks for your reply.
So in the new style, the input/output of the forward function are still tensors, but the grad_output of the backward function is a Variable, which enables the calculation of second derivatives?

Yes, exactly. And you save tensors with save_for_backward but retrieve them through saved_variables, just as you described.

Best regards

Thomas