Understanding the chain rule for matrices

Suppose the gradient of a scalar loss L with respect to a matrix A has shape (M, N), which is just the shape of A itself. Now suppose we want the derivative of the same scalar loss with respect to a matrix B of shape (M', N'), where A = f(B) for some function f.

Then,

  • dA/dB has shape (M, N, M', N') in PyTorch (this is the shape you get if you use torch.autograd.functional.jacobian; see the sketch after this list)
  • dL/dA has shape (M, N) in PyTorch
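For concreteness, here is a minimal sketch of the Jacobian shape (the shapes and the matrices W and V are made up, just to give f something concrete to do):

import torch

# B has shape (M', N') = (6, 7); f maps it to A with shape (M, N) = (4, 5).
B = torch.randn(6, 7)
W = torch.randn(4, 6)   # hypothetical fixed matrix, only to make f concrete
V = torch.randn(7, 5)   # hypothetical fixed matrix

def f(B):
    return W @ B @ V    # A = f(B), shape (4, 5)

J = torch.autograd.functional.jacobian(f, B)
print(J.shape)  # torch.Size([4, 5, 6, 7]) == (M, N, M', N')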

Clearly dL/dB should have shape (M', N'), but dL/dA @ dA/dB is invalid in PyTorch, even though the chain rule suggests this kind of composition should be valid mathematically.

How should I understand this issue? How should I compute dL/dB given dL/dA and dA/dB?

Example:

# Attempting dL/dA @ dA/dB, with dL/dA of shape (4, 5) and dA/dB of shape (4, 5, 6, 7)
torch.matmul(torch.randn(4, 5), torch.randn(4, 5, 6, 7))
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (140x6 and 5x4)
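My current guess is that the chain rule here is really a tensor contraction over the shared (M, N) indices rather than a 2-D matrix product, i.e. dL/dB[p, q] = sum over m, n of dL/dA[m, n] * dA/dB[m, n, p, q]. A sketch of what I think this looks like, assuming torch.einsum (or torch.tensordot) is the right tool:

import torch

dL_dA = torch.randn(4, 5)         # (M, N)
dA_dB = torch.randn(4, 5, 6, 7)   # (M, N, M', N')

# Contract over the shared (M, N) indices instead of doing a 2-D matmul.
dL_dB = torch.einsum('mn,mnpq->pq', dL_dA, dA_dB)
# Equivalent: torch.tensordot(dL_dA, dA_dB, dims=([0, 1], [0, 1]))
print(dL_dB.shape)  # torch.Size([6, 7]) == (M', N')

Is this contraction the right way to think about it?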