I am following this tutorial on forward-mode AD, particularly the “Usage with Modules” part, and I have also looked into the torch.autograd.forward_ad file.
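For reference, the “Usage with Modules” pattern can be sketched roughly as follows. This assumes PyTorch 2.0 or later (for torch.func.functional_call). Each parameter is paired with an arbitrary tangent as a dual tensor, and the tangent of the output is the JVP with respect to the parameters. The nn.Linear model, the shapes, and the random tangents are illustrative choices, not taken from the tutorial.

```python
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD
from torch.func import functional_call  # assumes PyTorch >= 2.0

torch.manual_seed(0)
model = nn.Linear(5, 3)
x = torch.randn(4, 5)

# Primal parameters and an arbitrary tangent (direction) for each of them
params = {name: p.detach() for name, p in model.named_parameters()}
tangents = {name: torch.randn_like(p) for name, p in params.items()}

with fwAD.dual_level():
    # Pair every parameter with its tangent
    dual_params = {
        name: fwAD.make_dual(p, tangents[name]) for name, p in params.items()
    }
    # Run the unmodified module with the dual parameters swapped in
    out = functional_call(model, dual_params, (x,))
    # Copy the output tangent out before leaving the dual level
    jvp = fwAD.unpack_dual(out).tangent.clone()

# For y = x W^T + b, the JVP along (dW, db) is x dW^T + db
expected = x @ tangents["weight"].T + tangents["bias"]
print(torch.allclose(jvp, expected, atol=1e-6))
```

The module code itself is untouched; only its parameters carry tangents, which is what makes the question below about where the per-layer JVPs come from a natural one.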
I was curious about how exactly the forward computations happen. That is, how and where are the intermediate JVPs computed? Is it related to the implementations of @staticmethod def jvp(ctx, gI)? If so, where can I find those definitions for the nn.Module layers of an arbitrary model?
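The jvp staticmethod mentioned above is what a custom torch.autograd.Function subclass defines to supply its own forward-mode rule. A minimal sketch follows; MyExp, ExpLayer, and the jvp_calls counter are hypothetical names used only for illustration. Under forward-mode AD, jvp receives the input tangent during the forward pass and returns the output tangent.

```python
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

jvp_calls = []  # hypothetical counter recording each time jvp runs


class MyExp(torch.autograd.Function):
    """Hypothetical custom op: y = exp(x), with a hand-written forward-mode rule."""

    @staticmethod
    def forward(ctx, x):
        result = torch.exp(x)
        ctx.result = result  # keep the primal output for use in jvp
        return result

    @staticmethod
    def jvp(ctx, gI):
        # gI is the tangent of the input; d/dx exp(x) = exp(x),
        # so the output tangent is exp(x) * gI
        jvp_calls.append(1)
        return gI * ctx.result


class ExpLayer(nn.Module):
    """Hypothetical module whose forward just calls the custom Function."""

    def forward(self, x):
        return MyExp.apply(x)


layer = ExpLayer()
x = torch.randn(3)
v = torch.randn(3)  # input tangent

with fwAD.dual_level():
    out = layer(fwAD.make_dual(x, v))
    tangent = fwAD.unpack_dual(out).tangent.clone()

print(len(jvp_calls), torch.allclose(tangent, torch.exp(x) * v))
```

Here ExpLayer defines no jvp of its own; the tangent is propagated by the rule attached to the Function it calls.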