How does forward-mode AD work behind the scenes in PyTorch?

I am following this tutorial on forward-mode AD, particularly the “Usage with Modules” part, and I have also looked into the torch.autograd.forward_ad module.

I am curious about how exactly the forward computations happen. That is, how and where are the intermediate JVPs computed? Is it related to the implementations of @staticmethod def jvp(ctx, gI)? If so, where can I find those definitions for the nn.Modules of an arbitrary model?
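
For reference, the basic pattern I am following looks roughly like this (a minimal sketch of the API from the tutorial, leaving out the module part):

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3)
v = torch.ones(3)  # tangent: the direction of the directional derivative

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, v)   # attach the tangent to the primal
    out = dual_x ** 2               # the forward pass carries the tangent along
    primal, tangent = fwAD.unpack_dual(out)

# tangent should equal 2 * x * v, i.e. the JVP of x ** 2 in direction v
```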

Derivatives for forward mode (and reverse mode) are actually implemented at the operator level rather than the module level. You can find both of them in tools/autograd/derivatives.yaml on the main branch of the pytorch/pytorch repository on GitHub.
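
So when you push a dual tensor through a module, there is no per-module JVP rule involved; the tangent is propagated by the forward-derivative formula of each underlying op the module calls. A minimal sketch with nn.Linear (the shapes and names are just illustrative):

```python
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

model = nn.Linear(4, 2)
x = torch.randn(1, 4)
x_tangent = torch.randn(1, 4)

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, x_tangent)
    out = model(dual_x)              # each op's forward-derivative formula runs here
    _, jvp = fwAD.unpack_dual(out)

# A linear layer is linear in its input, so the JVP w.r.t. the input is just
# the tangent pushed through the weight:
print(torch.allclose(jvp, x_tangent @ model.weight.T))
```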

When writing a custom autograd Function, the @staticmethod def jvp(ctx, gI) method is what you can use to define the JVP rule.
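
A minimal sketch of what that looks like (the op and the names here are just illustrative; tensors stashed directly on ctx in forward can be reused in jvp):

```python
import torch
import torch.autograd.forward_ad as fwAD

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Stash what the jvp rule needs directly on ctx.
        ctx.x = x
        return x ** 2

    @staticmethod
    def jvp(ctx, gI):
        # gI is the tangent of the input; return the tangent of the output:
        # d(x ** 2) = 2 * x * dx
        return 2 * ctx.x * gI

    @staticmethod
    def backward(ctx, gO):
        # Reverse-mode rule, included only for completeness.
        return 2 * ctx.x * gO

primal = torch.randn(3, requires_grad=True)
tangent = torch.ones(3)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = Square.apply(dual)
    print(fwAD.unpack_dual(out).tangent)  # == 2 * primal * tangent
```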