I’m unsure what the most efficient implementation is if both my inputs and my outputs are batched. Specifically, if I have inputs of shape [B, n] and func maps them to outputs of shape [B, m], then calling jac = torch.autograd.functional.jacobian(func, inputs, vectorize=True) returns a tensor of shape [B, m, B, n] (output shape followed by input shape).
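
For concreteness, here is a rough sketch of the kind of setup I mean (func here is just a made-up per-sample example, not my actual function):

```python
import torch

# toy stand-in for my setup
B, n, m = 4, 3, 5
W = torch.randn(n, m)

def func(x):
    # per-sample map [B, n] -> [B, m]; no interaction across the batch dimension
    return torch.tanh(x @ W)

inputs = torch.randn(B, n)
jac = torch.autograd.functional.jacobian(func, inputs, vectorize=True)
print(jac.shape)  # torch.Size([4, 5, 4, 3]), i.e. [B, m, B, n]
```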
But if there is no interaction between batch elements, then the cross-batch blocks jac[i, :, j, :] with i != j are just zero tensors, and all I really need are the per-sample blocks jac[i, :, i, :], i.e. a tensor of shape [B, m, n]. Of course, I can just select the appropriate entries of jac, but I’m wondering whether that is still the most efficient approach here, since a lot of unnecessary gradients are computed. Is there a better way?
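
To show what I mean by selecting the appropriate entries, this is roughly the post-processing I'm doing now (reusing jac, B, m, n from the sketch above; the indexing is just one way to pull out the block diagonal):

```python
# keep only the per-sample blocks jac[i, :, i, :]
idx = torch.arange(B)
jac_per_sample = jac[idx, :, idx, :]   # non-adjacent advanced indices move the batch dim to the front
print(jac_per_sample.shape)            # torch.Size([4, 5, 3]), i.e. [B, m, n]

# sanity check: the cross-batch blocks really are zero for a per-sample func
assert torch.allclose(jac[0, :, 1, :], torch.zeros(m, n))
```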
Seems to be the same problem as in this thread.