Computing batch Jacobian efficiently

I’m unsure what the most efficient implementation is if both my inputs and the outputs are batched. Specifically, if I have inputs of shape [B, n] and func maps to outputs of shape [B, m], then calling

jac = torch.autograd.functional.jacobian(func, inputs, vectorize=True)

returns a tensor of shape [B, m, B, n] (output dimensions first, then input dimensions).

But if there is no interaction between batch elements, then jac[i, :, j, :] is a zero tensor whenever i ≠ j, and I only really need the diagonal blocks, jac[i, :, :] = jac[i, :, i, :].
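To make the block structure concrete, here is a minimal sketch. The function `func`, the weight `W`, and the sizes are made up for illustration; the only assumption is that each output row depends solely on the matching input row. `torch.autograd.functional.jacobian` lays out its result as the output shape followed by the input shape.

```python
import torch

torch.manual_seed(0)
B, n, m = 4, 3, 2
W = torch.randn(n, m)  # hypothetical weights, for illustration only


def func(x):
    # Hypothetical batched function: row b of the output depends only on row b of x.
    return torch.tanh(x @ W)


inputs = torch.randn(B, n)

# Full Jacobian: shape is output shape + input shape.
jac = torch.autograd.functional.jacobian(func, inputs, vectorize=True)
print(jac.shape)  # torch.Size([4, 2, 4, 3])

# Cross-batch blocks vanish because samples do not interact.
print(torch.count_nonzero(jac[0, :, 1, :]))  # tensor(0)

# Select the diagonal blocks jac[i, :, i, :] for every i at once.
idx = torch.arange(B)
batch_jac = jac[idx, :, idx, :]
print(batch_jac.shape)  # torch.Size([4, 2, 3])
```

This still computes all B * B blocks and then throws away the B * (B - 1) zero ones, which is exactly the waste I would like to avoid.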

Of course, I can just select the appropriate entries of jac, but I’m wondering if this is still the most efficient approach here since a lot of unnecessary gradients are computed. Is there a better way?
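Two candidates that seem to avoid the zero blocks, sketched under the same assumption that samples do not interact. `func`, `W`, and the sizes are hypothetical. Option 2 assumes PyTorch 2.0 or later (for `torch.func`) and that `func` also accepts a single unbatched sample of shape [n]. I have not benchmarked either.

```python
import torch
from torch.func import jacrev, vmap

torch.manual_seed(0)
B, n, m = 4, 3, 2
W = torch.randn(n, m)  # hypothetical weights, for illustration only


def func(x):
    # Works on a batch [B, n] -> [B, m] and on a single sample [n] -> [m].
    return torch.tanh(x @ W)


inputs = torch.randn(B, n)

# Option 1: sum the outputs over the batch dimension before differentiating.
# Since sample b only affects output row b, d(sum_b' y[b'])/dx[b] = dy[b]/dx[b].
jac_sum = torch.autograd.functional.jacobian(
    lambda x: func(x).sum(dim=0), inputs, vectorize=True
)  # shape [m, B, n]
batch_jac_sum = jac_sum.permute(1, 0, 2)  # shape [B, m, n]

# Option 2: take the Jacobian of the per-sample function and vectorize it over the batch.
batch_jac_vmap = vmap(jacrev(func))(inputs)  # shape [B, m, n]

print(batch_jac_sum.shape, batch_jac_vmap.shape)
```

Option 1 is only valid when there really is no cross-batch interaction (so no batch norm or similar); otherwise the sum mixes contributions from different samples.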

Seems to be the same problem as in this thread.
