import torch
from torch.autograd import Function


class ExpTest(Function):
    @staticmethod
    def forward(ctx, i):
        print(ctx)
        result = i.exp()
        return result


x = torch.tensor([1, 2, 3.], requires_grad=True)
output = ExpTest.apply(x)
print(output.grad_fn)
# Output of print(ctx) inside forward:
# <torch.autograd.function.ExpTestBackward object at 0x7f0f376d84f0>
# Output of print(output.grad_fn):
# <torch.autograd.function.ExpTestBackward object at 0x7f0f376d84f0>
As shown above, both prints show the same ExpTestBackward object, so it looks like the apply method first instantiates an ExpTestBackward object and then passes it to forward as ctx. But where is this apply method defined? I could not find such a method in the torch/autograd/function.py file.
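One way to narrow this down might be to ask Python which class in the inheritance chain actually defines the attribute. The sketch below is only an illustration: classes_defining is a hypothetical helper, and ToyBase and ToyExp are stand-in classes, not PyTorch ones. It walks the MRO and reports every class whose own namespace contains the name.

```python
def classes_defining(cls, name):
    """Return the classes in cls's MRO whose own namespace defines `name`."""
    return [klass for klass in cls.__mro__ if name in vars(klass)]


# Hypothetical stand-ins (not PyTorch classes): the base class provides
# apply, while the subclass only provides forward, as in the question.
class ToyBase:
    @classmethod
    def apply(cls, *args):
        # Create a context object and hand it to the subclass's forward.
        ctx = object()
        return cls.forward(ctx, *args)


class ToyExp(ToyBase):
    @staticmethod
    def forward(ctx, i):
        return i * 2


# Lists the classes that define apply; here that is only ToyBase.
print(classes_defining(ToyExp, "apply"))
```

Running classes_defining(ExpTest, "apply") on the class from the question should list the base classes that define apply in the installed PyTorch. The result depends on the version. If the defining class turns out to be a base implemented in the C extension rather than in Python, that would explain why the method cannot be found in the Python source file.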