Hi, I find there are two ways in PyTorch to extend `torch.autograd` by creating subclasses of `torch.autograd.Function`.
<1> Take this function as an example:
```python
from torch.autograd.function import Function

class FuncA(Function):
    def __init__(self):
        super(FuncA, self).__init__()

    def forward(self, input):
        output = input.new()
        # calculation of output here
        self.save_for_backward(input)
        return output

    def backward(self, grad_output):
        input, = self.saved_tensors  # saved_tensors returns a tuple
        grad_input = grad_output.new()
        # calculation of input gradient here
        return grad_input
```
In this example, `input`, `output`, `grad_output`, and `grad_input` are all `torch.Tensor`.
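For concreteness, here is a minimal sketch of how this old style is invoked (my own `Identity` function is a stand-in for the omitted calculation; note this instance-method style is deprecated in recent PyTorch releases):

```python
import torch
from torch.autograd import Variable
from torch.autograd.function import Function

# Hypothetical example: identity just passes the input through.
class Identity(Function):
    def forward(self, input):
        self.save_for_backward(input)
        return input.clone()

    def backward(self, grad_output):
        input, = self.saved_tensors
        return grad_output.clone()

x = Variable(torch.randn(3), requires_grad=True)
y = Identity()(x)  # old style: instantiate the Function, then call the instance
y.sum().backward()
print(x.grad)      # all ones, since d(sum(x))/dx = 1 elementwise
```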
<2> And take this function as an example:
```python
from torch.autograd.function import Function

class FuncB(Function):
    @staticmethod
    def forward(ctx, input):
        output = input.new()
        # calculation of output here
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):  # note: ctx, not self
        input, = ctx.saved_variables
        grad_input = FuncBBackward.apply(grad_output)
        return grad_input
```
In this example, `input` and `output` are `torch.Tensor`, while `grad_output` and `grad_input` are `torch.autograd.variable.Variable`. `FuncBBackward` is also a subclass of `torch.autograd.Function`.
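Again for concreteness, a minimal sketch of how this new style is invoked (same hypothetical identity stand-in; in current PyTorch releases `ctx.saved_variables` has been replaced by `ctx.saved_tensors`):

```python
import torch
from torch.autograd import Variable
from torch.autograd.function import Function

# Hypothetical example: identity just passes the input through.
class IdentityB(Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_variables  # spelled ctx.saved_tensors in later releases
        return grad_output.clone()

x = Variable(torch.randn(3), requires_grad=True)
y = IdentityB.apply(x)  # new style: never instantiated, always called through .apply
y.sum().backward()
print(x.grad)           # all ones, as above
```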
QUESTION
What is the difference between these two ways?
Why is `@staticmethod` used for the second one, but not for the first one?
Why does `grad_output` have a different data type in the two methods?
Is there any difference between `self.save_for_backward(input)` / `input = self.saved_tensors` and `ctx.save_for_backward(input)` / `input = ctx.saved_variables`?
Thanks~