I have a function
f : R -> R^d,
d >> 1, for which I would like to compute a batched derivative. Since
d is large, I assume that forward-mode AD is better suited than backward (reverse-mode) AD.
My attempt is the following:
```python
import torch
import torch.autograd.forward_ad as fwAD

tangent = torch.ones_like(x_in).to('cuda')
with fwAD.dual_level():
    dual_input = fwAD.make_dual(x_in, tangent)
    dual_output = model(dual_input)
    # unpack_dual is called inside the dual_level context so the tangent is available
    y, jvp_fwAD = fwAD.unpack_dual(dual_output)
```
However, this approach does not seem to be significantly faster or more memory-efficient than simply using
```python
from torch.autograd.functional import jvp

v = torch.ones_like(x_in).to('cuda')
y, jvp_jvp = jvp(model, x_in, v)
```
which, as far as I understand, computes the JVP via two backward passes (the double-backward trick)?
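For reference, here is a minimal, self-contained version of the comparison I am timing. A toy MLP stands in for my actual model, the sizes (d = 4096, batch of 128 scalar inputs) are just illustrative, and the helper names `run_fwad` / `run_double_backward` exist only for this sketch:

```python
import time
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD
from torch.autograd.functional import jvp

device = 'cuda' if torch.cuda.is_available() else 'cpu'
d = 4096                                   # large output dimension
model = nn.Sequential(nn.Linear(1, 512), nn.Tanh(), nn.Linear(512, d)).to(device)
x_in = torch.randn(128, 1, device=device)  # batch of scalar inputs
tangent = torch.ones_like(x_in)            # d(input)/d(input) = 1 for a scalar input

def run_fwad():
    # forward-mode AD: one forward pass yields both the output and the JVP
    with fwAD.dual_level():
        dual_in = fwAD.make_dual(x_in, tangent)
        y, jvp_fw = fwAD.unpack_dual(model(dual_in))
    return y, jvp_fw

def run_double_backward():
    # torch.autograd.functional.jvp: JVP via the double-backward trick
    return jvp(model, x_in, tangent)

for name, fn in [('forward-mode AD', run_fwad), ('double-backward jvp', run_double_backward)]:
    fn()                                   # warm-up
    if device == 'cuda':
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    fn()
    if device == 'cuda':
        torch.cuda.synchronize()
    print(f'{name}: {time.perf_counter() - t0:.4f}s')
```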
Any idea what’s the best way to take the derivative of such a function?