Forward Mode AD with multiple tangents for each primal

Hello,

I am currently exploring ways to compute Jacobian-vector products (JVPs) in PyTorch and have come across a few resources on forward-mode AD. However, I am still relatively new to automatic differentiation and find it challenging to fully understand how forward-mode AD operates and how it integrates with the forward pass.

Questions:

  1. Compared to computing a VJP (vector-Jacobian product) with grad in the backward pass, how does forward-mode AD differ in terms of:

    • Speed of computation?
    • Memory usage?
  2. My main goal is to compute JVPs efficiently for multiple tangents per primal, i.e., each primal is associated with multiple tangent vectors. Is there a way to compute JVPs for multiple tangents in parallel without having to recompute f(primal) for each tangent vector?

Any insights or examples would be greatly appreciated!

Thank you!

  1. Speed of computation should be similar. Memory usage for forward-mode AD is lower, since you no longer need to save intermediate activations for a backward pass; the tangent is carried along with the primal values during the forward computation instead. See the short jvp/vjp sketch after this list.

  2. You likely want vectorized JVPs. This is possible using torch.func by combining torch.func.vmap with torch.func.jvp (see the PyTorch documentation for both); a sketch combining them is included below.
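
To make the first point concrete, here is a minimal sketch contrasting the two torch.func APIs on the same function; `f`, `x`, and `v` are illustrative placeholders, not anything from the original post:

```python
import torch
from torch.func import jvp, vjp

def f(x):
    return x.sin()  # placeholder function

x = torch.randn(5)  # primal
v = torch.randn(5)  # tangent (for jvp) / cotangent (for vjp)

# Forward mode: the tangent is pushed through in the same pass that
# computes f(x), so no intermediate activations need to be stored.
out_fwd, jvp_out = jvp(f, (x,), (v,))

# Reverse mode: vjp first runs f and records what the backward pass needs,
# then vjp_fn pulls the cotangent back through that recorded computation.
out_rev, vjp_fn = vjp(f, x)
vjp_out = vjp_fn(v)[0]
```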
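
And here is a minimal sketch of the vectorized version for the second point, assuming a single primal `x` and a batch of tangent vectors (the function `f` and the shapes are again placeholders):

```python
import torch
from torch.func import jvp, vmap

def f(x):
    return x.sin().sum(-1)  # placeholder function

x = torch.randn(5)            # one primal
tangents = torch.randn(8, 5)  # 8 tangent vectors for that same primal

def jvp_single(t):
    # jvp takes tuples of primals and tangents and returns (f(x), J_f(x) @ t).
    return jvp(f, (x,), (t,))

# vmap vectorizes over the leading dimension of `tangents`; x is closed
# over, so it is treated as a single (unbatched) primal.
primal_out, jvp_out = vmap(jvp_single)(tangents)  # one JVP per tangent
```

For the special case where the tangents are the standard basis vectors, torch.func.jacfwd builds the full Jacobian in essentially this way.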