NotImplementedError: You must implement the jvp function for custom autograd.Function to use it with forward mode AD

TL;DR - How do you implement a custom jvp method for a custom torch.autograd.Function?

Hi All,

I’ve been trying to make a forward-over-reverse function to efficiently compute the Laplacian of a given function, and I’ve been expanding upon what was discussed here

I understand that using forward-over-reverse isn’t vectorized so I won’t get the ideal speed-up but I just wanted to test it out anyway. My current version of the function is here,

import torch
import torch.autograd.forward_ad as fwAD

def laplacian_from_log_foward_reverse(func, xs):
  with fwAD.dual_level():
    jacobian = torch.zeros(*xs.shape, device=xs.device, dtype=xs.dtype)
    laplacian = torch.zeros(*xs.shape, device=xs.device, dtype=xs.dtype)
    for i in range(xs.shape[-1]):
      tangent = torch.zeros_like(xs)
      tangent[:,i] = 1 #mark the index for forward-ad

      dual_in = fwAD.make_dual(xs, tangent)
      dual_out = func(dual_in)
      primal_out, tangent_out = fwAD.unpack_dual(dual_out) #tangent_out = d func / d x_i

      jacobian[:,i] = tangent_out

      #reverse over the forward tangent to get the second derivative w.r.t. x_i
      out = torch.autograd.grad(tangent_out, xs, torch.ones_like(tangent_out), retain_graph=True, create_graph=True)[0]

      laplacian[:,i] = out[:,i]
    return torch.sum(laplacian + jacobian.pow(2), dim=-1)

When I pass in my function (which is an nn.Module) that contains a custom torch.autograd.Function, it fails with the NotImplementedError above. I know the Laplacian routine itself works, as testing it on,

def func(x):
  return x.pow(2).sum(dim=-1)

works completely fine and matches reverse-over-reverse methods, albeit slower. But how do you define a custom jvp function so that forward_ad can differentiate the function?
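
For concreteness, a minimal reverse-over-reverse reference to compare against could look something like the sketch below (laplacian_reference is just a made-up helper for illustration, not anything official):

import torch

def laplacian_reference(func, xs):
  # plain reverse-over-reverse: build the gradient, then differentiate each component again
  out = func(xs)
  grad, = torch.autograd.grad(out, xs, torch.ones_like(out), create_graph=True)
  lap = torch.zeros(xs.shape[0], device=xs.device, dtype=xs.dtype)
  for i in range(xs.shape[-1]):
    second, = torch.autograd.grad(grad[:, i], xs, torch.ones_like(grad[:, i]), retain_graph=True)
    lap = lap + second[:, i]
  return lap + grad.pow(2).sum(dim=-1)

xs = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
print(torch.allclose(laplacian_from_log_foward_reverse(func, xs), laplacian_reference(func, xs)))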

Any help would be greatly appreciated! :slight_smile:

You literally write a static method jvp that takes grad_ins for all its arguments and returns the matching grad_outs (and rename backward to vjp if you want to be cool).

I haven’t really found a good way to save inputs to the function for use in the jvp - save_for_backward does not work in forward mode, maybe @alband knows.
(NOTE: If you’re reading this and it’s 2022 or later, quite likely something has been implemented for it.)

import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, f1, f2, mul=True):
        assert not (f1.is_complex() or f2.is_complex()), "Complex not supported"
        ctx.mul = mul
        if mul:
            ctx.save_for_backward(f1, f2)
            # also stash the inputs directly on ctx so the jvp can see them
            # (save_for_backward only feeds the reverse-mode vjp)
            ctx.f1 = f1
            ctx.f2 = f2
            return f1 * f2
        else:
            return f1 + f2

    @staticmethod
    def vjp(ctx, grad_out):  # reverse mode, a.k.a. backward
        if ctx.mul:
            f1, f2 = ctx.saved_tensors
            grad_f1 = (grad_out * f2).sum_to_size(f1.shape)
            grad_f2 = (f1 * grad_out).sum_to_size(f2.shape)
            return grad_f1, grad_f2, None
        else:
            return grad_out, grad_out, None

    @staticmethod
    def jvp(ctx, grad_f1, grad_f2, _1):  # forward mode: input tangents in, output tangent out
        if ctx.mul:
            f1, f2 = ctx.f1, ctx.f2
            # in forward mode the stashed inputs are dual tensors, take the primal part
            f1, _ = torch.autograd.forward_ad.unpack_dual(f1)
            f2, _ = torch.autograd.forward_ad.unpack_dual(f2)
            return grad_f1 * f2 + f1 * grad_f2
        else:
            return grad_f1 + grad_f2


f1 = torch.randn(5, 1, requires_grad=True, dtype=torch.float64)
f2 = torch.randn(1, 5, requires_grad=True, dtype=torch.float64)
torch.autograd.gradcheck(lambda f1, f2: MyFunction.apply(f1, f2, False), (f1, f2), check_forward_ad=True)
torch.autograd.gradcheck(lambda f1, f2: MyFunction.apply(f1, f2, True), (f1, f2), check_forward_ad=True)
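
Just to see the jvp being exercised outside of gradcheck, a quick sketch with an explicit dual level (reusing the f1 and f2 above) might look like this:

import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    t1, t2 = torch.ones_like(f1), torch.ones_like(f2)  # the two directions
    dual_out = MyFunction.apply(fwAD.make_dual(f1, t1), fwAD.make_dual(f2, t2), True)
    primal, tangent = fwAD.unpack_dual(dual_out)
    # product rule: d(f1 * f2) = t1 * f2 + f1 * t2
    print(torch.allclose(tangent, t1 * f2 + f1 * t2))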

Best regards

Thomas


Hi @tom! Thanks for the quick response!

When it comes to actually writing what goes into the jvp and vjp methods, are there any examples?

For example, here’s an example for f(x) = x^3 (How to customize the double backward? - #4 by Naruto-Sasuke) where you’d just need to define the backward as grad_out * 3 * x**2. That is, in general, just the derivative of the output of the layer w.r.t. its input, which you then multiply by the grad_out term to get the gradient of the loss w.r.t. the input.
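
So for the cube example, am I right that the two methods would look roughly like this? (The class below is just my attempt, modelled on your MyFunction pattern.)

import torch

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.x = x  # stashed on ctx so both vjp and jvp can reach it
        return x ** 3

    @staticmethod
    def vjp(ctx, grad_out):  # reverse mode: grad_out times dout/dx
        x, _ = torch.autograd.forward_ad.unpack_dual(ctx.x)
        return grad_out * 3 * x ** 2

    @staticmethod
    def jvp(ctx, tangent_x):  # forward mode: dout/dx times the input tangent
        x, _ = torch.autograd.forward_ad.unpack_dual(ctx.x)
        return 3 * x ** 2 * tangent_x

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(Cube.apply, (x,), check_forward_ad=True)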

Also, I’ve never seen the ctx.mul before, so I assume that’s a new feature in master?

Thank you for the help, it’s greatly appreciated! :slight_smile:

The assignments to ctx are made up on the spot and not something PyTorch provides or keeps you from doing. For input and output tensors, it is recommended to use save_for_backward to avoid memory cycles and to check for in-place modification, but for other purposes we just assign to ctx. I don’t know if there is a “save for jvp”; I have not seen one yet.

Best regards

Thomas

ahhh fair enough, that makes more sense!

But whether it’s jvp or vjp, is the actual formula within that method just the derivative of the output of the module w.r.t. its input? Is that correct? Or is there a difference when using forward-mode AD?

Thank you for the help! :slight_smile:

For scalar input and output the two look the same, but once you leave that, they are quite different (because matrix multiplication is not commutative):

  • jvp multiplies the Jacobian of the forward to the left of a vector (J·v). This lets you compute the directional derivative of a function (but you have to give the direction at the start),
  • vjp multiplies the Jacobian of the forward to the right of a vector (vᵀ·J). This is what backprop does, and it lets you compute the gradient of a scalar function (the direction of steepest descent).
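
In terms of shapes, for f : R^n → R^k you hand vjp a vector from the output space (like grad_out) and jvp a vector from the input space (a direction). A quick illustration with torch.autograd.functional, using a made-up toy f:

import torch
from torch.autograd.functional import jvp, vjp

def f(x):  # R^3 -> R^2
    return torch.stack([x.pow(2).sum(), x.prod()])

x = torch.randn(3, dtype=torch.float64)
v_out = torch.randn(2, dtype=torch.float64)  # lives in the output space
v_in = torch.randn(3, dtype=torch.float64)   # lives in the input space

_, vT_J = vjp(f, x, v_out)  # v^T J, shape (3,): what backprop propagates
_, J_v = jvp(f, x, v_in)    # J v,   shape (2,): a directional derivative
print(vT_J.shape, J_v.shape)

(Note that torch.autograd.functional.jvp currently computes this with a double-backward trick rather than the forward-mode machinery in torch.autograd.forward_ad, but the shapes are the point here.)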

Best regards

Thomas


And the vector here is the torch.ones_like(out)? And the Jacobian is the derivative of the forward w.r.t. whatever you’re differentiating? Is that correct?

Are there any textbooks (or other material) that would explain this in finer detail?

Thank you for your help! :slight_smile:

So I have a Ph.D. in analysis, and that makes me not the best person to ask, because I’ll either say that you need to do it yourself or get out pen and paper to do it with you. (I also made a 5+ hour video course on autograd, which I won’t advertise here, though I need to update the session on forward differentiation with the progress that has been made, and you might prefer the free material out there.)

If you wanted to do it yourself, you could write down the chain rule for three functions, say (grossly oversimplifying)
f : Parameters → Intermediates, g : Intermediates → Outputs, h : Outputs → Loss,
where Parameters, Intermediates, and Outputs live in some n-, m-, and k-dimensional space respectively and the Loss is scalar.

Then you spell out the chain rule (writing Df = dIntermediate/dParameter for the m × n matrix with the partial derivatives ∂Intermediate_i/∂Parameter_j) for various things, e.g.

dLoss/dParameter = dLoss/dLoss · dLoss/dOutput · dOutput/dIntermediate · dIntermediate/dParameter

and then you start evaluating this from left to right in the backward pass (dLoss/dLoss = 1). The intermediate results will be “row vectors” (grad_out and grad_in in autograd lingo) and the factors are the Jacobians that are applied in the “backward”.
In recent PyTorch, you can call the .grad_fn nodes, so you can indeed compute this step by step for your viewing pleasure (but you have to sum_to_size whenever broadcasting happens).

When you do forward AD, you start with a vector (the direction) on the right hand side of the chain rule.

The idea is not to get matrix-valued intermediates (as these cost a lot of memory and compute to process).
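
To make that concrete, here is a toy version of the chain above (the three functions are made up: f : R^3 → R^3, g : R^3 → R^2, h : R^2 → R), composing the vjps from left to right and, for comparison, pushing a direction through forward mode:

import torch
import torch.autograd.forward_ad as fwAD

f = lambda p: p.sin()                                 # Parameters    -> Intermediates
g = lambda i: torch.stack([i.sum(), i.pow(2).sum()])  # Intermediates -> Outputs
h = lambda o: o.pow(2).sum()                          # Outputs       -> Loss

p = torch.randn(3, dtype=torch.float64, requires_grad=True)
inter = f(p)
outs = g(inter)
loss = h(outs)

# backward pass: evaluate left to right, starting from dLoss/dLoss = 1
dLoss_dOut, = torch.autograd.grad(loss, outs, retain_graph=True)  # "row vector" of size 2
dLoss_dInt, = torch.autograd.grad(outs, inter, dLoss_dOut, retain_graph=True)
dLoss_dPar, = torch.autograd.grad(inter, p, dLoss_dInt)           # the gradient, size 3

# forward mode: start from a direction v on the right and push it through
v = torch.randn(3, dtype=torch.float64)
with fwAD.dual_level():
    _, dLoss_v = fwAD.unpack_dual(h(g(f(fwAD.make_dual(p.detach(), v)))))

print(torch.allclose(dLoss_v, dLoss_dPar @ v))  # same directional derivative either way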

Best regards

Thomas
