NotImplementedError: You must implement the jvp function for custom autograd.Function to use it with forward mode AD

You literally write a static method jvp that takes the grad_ins (forward-mode tangents) for all of forward's arguments and returns the matching grad_outs (and rename backward to vjp if you want to be cool).

I haven’t really found a good way to save inputs to the function for use in the jvp - save_for_backward does not work in forward mode, maybe @alband knows.
(NOTE: If you’re reading this and it’s 2022 or later, quite likely something has been implemented for it.)

import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, f1, f2, mul=True):
        assert not (f1.is_complex() or f2.is_complex()), "Complex not supported"
        ctx.mul = mul
        if mul:
            # save_for_backward only feeds the backward/vjp; stash the inputs
            # on ctx as well so that jvp can get at them
            ctx.save_for_backward(f1, f2)
            ctx.f1 = f1
            ctx.f2 = f2
            return f1 * f2
        else:
            return f1 + f2

    @staticmethod
    def vjp(ctx, grad_out):
        if ctx.mul:
            f1, f2 = ctx.saved_tensors
            # gradients of f1 * f2; sum_to_size reduces over broadcast dimensions
            grad_f1 = (grad_out * f2).sum_to_size(f1.shape)
            grad_f2 = (f1 * grad_out).sum_to_size(f2.shape)
            return grad_f1, grad_f2, None
        else:
            return grad_out, grad_out, None

    @staticmethod
    def jvp(ctx, grad_f1, grad_f2, _grad_mul):
        # one tangent per argument of forward; the non-tensor mul gets None
        if ctx.mul:
            # forward may have been handed dual tensors, so unpack to the primals
            f1, _ = torch.autograd.forward_ad.unpack_dual(ctx.f1)
            f2, _ = torch.autograd.forward_ad.unpack_dual(ctx.f2)
            return grad_f1 * f2 + f1 * grad_f2
        else:
            return grad_f1 + grad_f2

# gradcheck with check_forward_ad=True exercises both the vjp and the jvp
f1 = torch.randn(5, 1, requires_grad=True, dtype=torch.float64)
f2 = torch.randn(1, 5, requires_grad=True, dtype=torch.float64)
torch.autograd.gradcheck(lambda f1, f2: MyFunction.apply(f1, f2, False), (f1, f2), check_forward_ad=True)
torch.autograd.gradcheck(lambda f1, f2: MyFunction.apply(f1, f2, True), (f1, f2), check_forward_ad=True)
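
For completeness, here is a sketch of how the jvp then gets used through the public forward-mode API (dual_level / make_dual / unpack_dual); the dual_f1/dual_f2 names, the random tangents, and the .detach() are just for illustration:

import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    # attach a tangent to each primal input
    dual_f1 = fwAD.make_dual(f1.detach(), torch.randn_like(f1))
    dual_f2 = fwAD.make_dual(f2.detach(), torch.randn_like(f2))
    out = MyFunction.apply(dual_f1, dual_f2, True)
    # the tangent of the output is what jvp computed
    out_primal, out_tangent = fwAD.unpack_dual(out)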

Best regards

Thomas
