You literally write a static method jvp that takes grad_ins for all its arguments and returns the matching grad_outs (and rename backward to vjp if you want to be cool).
I haven’t really found a good way to save inputs to the function for use in the jvp — save_for_backward does not work in forward mode; maybe @alband knows.
(NOTE: If you’re reading this and it’s 2022 or later, quite likely something has been implemented for it.)
class MyFunction(torch.autograd.Function):
    """Custom autograd Function computing ``f1 * f2`` (``mul=True``) or
    ``f1 + f2`` (``mul=False``), differentiable in both reverse mode
    (``vjp``) and forward mode (``jvp``).
    """

    @staticmethod
    def forward(ctx, f1, f2, mul=True):
        """Return ``f1 * f2`` when ``mul`` is True, else ``f1 + f2``."""
        assert not (f1.is_complex() or f2.is_complex()), "Complex not supported"
        ctx.mul = mul
        if not mul:
            # Addition needs no saved operands for either AD mode.
            return f1 + f2
        # Reverse mode reads the operands back via ctx.saved_tensors;
        # save_for_backward does not work in forward mode, so the operands
        # are additionally stashed on ctx directly for use in jvp.
        ctx.save_for_backward(f1, f2)
        ctx.f1 = f1
        ctx.f2 = f2
        return f1 * f2

    @staticmethod
    def vjp(ctx, grad_out):
        """Reverse-mode derivative (the 'cool' name for backward)."""
        if not ctx.mul:
            # d(f1 + f2): the upstream gradient flows through unchanged.
            return grad_out, grad_out, None
        f1, f2 = ctx.saved_tensors
        # sum_to_size reduces away the axes broadcast in forward.
        return (
            (grad_out * f2).sum_to_size(f1.shape),
            (f1 * grad_out).sum_to_size(f2.shape),
            None,
        )

    @staticmethod
    def jvp(ctx, grad_f1, grad_f2, _mul_grad):
        """Forward-mode derivative; takes one grad_in per forward input."""
        if not ctx.mul:
            return grad_f1 + grad_f2
        # The stashed operands may be dual tensors; use their primals.
        primal1, _ = torch.autograd.forward_ad.unpack_dual(ctx.f1)
        primal2, _ = torch.autograd.forward_ad.unpack_dual(ctx.f2)
        # Product rule: d(f1*f2) = df1*f2 + f1*df2.
        return grad_f1 * primal2 + primal1 * grad_f2
# Broadcasting operand shapes (5, 1) x (1, 5); float64 keeps gradcheck's
# finite-difference comparisons numerically tight.
f1 = torch.randn(5, 1, requires_grad=True, dtype=torch.float64)
f2 = torch.randn(1, 5, requires_grad=True, dtype=torch.float64)
# gradcheck raises on mismatch; check_forward_ad also exercises jvp.
for use_mul in (False, True):
    torch.autograd.gradcheck(
        lambda a, b, m=use_mul: MyFunction.apply(a, b, m),
        (f1, f2),
        check_forward_ad=True,
    )
Best regards
Thomas