If I have a function like

```
f = lambda x: 2 * x
```

and I want to calculate the VJP when it is applied to a matrix, I can do

```
import torch
from torch.autograd.functional import vjp

point = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
v = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
vjp(f, point, v)[1]
# tensor([[ 2.,  4.,  6.],
#         [ 8., 10., 12.]])
```
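As a sanity check (assuming `torch.autograd.functional.vjp`, which takes the cotangent `v` as its third argument and returns an `(output, vjp)` pair), the result is simply `2 * v`, since `f` is linear with Jacobian `2I`:

```python
import torch
from torch.autograd.functional import vjp

f = lambda x: 2 * x
point = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
v = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

# vjp returns (f(point), pullback of v); for f(x) = 2x the pullback is 2 * v
out, grad = vjp(f, point, v)
print(torch.equal(grad, 2 * v))  # True
```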

But what is the equivalent product using the size (2, 3, 2, 3) Jacobian computed by

```
import torch
from torch.autograd.functional import jacobian

jac = jacobian(f, torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
```
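A quick sketch (assuming `torch.autograd.functional.jacobian`) confirming the shape and structure of this Jacobian: each output element depends only on the matching input element, so flattening it gives `2 * I`:

```python
import torch
from torch.autograd.functional import jacobian

f = lambda x: 2 * x
jac = jacobian(f, torch.tensor([[1., 2., 3.], [4., 5., 6.]]))

print(jac.shape)  # torch.Size([2, 3, 2, 3])
# jac[i, j, k, l] = d f(x)[i, j] / d x[k, l], which is 2 when (i, j) == (k, l)
print(torch.equal(jac.reshape(6, 6), 2 * torch.eye(6)))  # True
```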

The obvious ways of multiplying the Jacobian and the vector do not give the same result as the VJP:

```
# all return rank-4 tensors, not the (2, 3) VJP
jac @ v.T   # shape (2, 3, 2, 2)
v @ jac.T   # shape (3, 2, 2, 2)
v.T @ jac   # shape (2, 3, 3, 3)
jac.T @ v   # shape (3, 2, 3, 3)
```
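A sketch checking that claim (assuming `torch.autograd.functional` for both helpers): `@` contracts only a single dimension, so each product above keeps four dimensions, while the VJP has the shape of the input, `(2, 3)`:

```python
import torch
from torch.autograd.functional import jacobian, vjp

f = lambda x: 2 * x
point = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
v = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

jac = jacobian(f, point)
jacT = jac.permute(3, 2, 1, 0)  # same as jac.T (all dims reversed); .T on 4-D tensors is deprecated

# each matmul contracts one dimension, so every product stays rank 4
for p in (jac @ v.T, v @ jacT, v.T @ jac, jacT @ v):
    print(tuple(p.shape))

# whereas the VJP has the shape of the input
print(tuple(vjp(f, point, v)[1].shape))  # (2, 3)
```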