Hi, I am trying to figure out how the tensor gradient is calculated when using the einsum operator. Can someone point me towards documentation for this, or at least the internal implementation in the source code?
Good question! It's actually one of the exercises in my advanced autograd class.
The answer lies in the autograd graph:
```python
import torch

x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
a = torch.einsum('ik,kj->ij', x, y)
print(a)  # the printed tensor includes its grad_fn
```
Running this will show you that the grad_fn of a is some kind of view backward. And indeed, if you follow the graph further back, you see that einsum reduces to reshaping operations and batch matrix multiplication (bmm).
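To see those nodes yourself, you can walk the graph from a.grad_fn through next_functions. This is a small sketch; graph_nodes is a hypothetical helper, and the exact node names (e.g. whether a Bmm node or an extra view appears) depend on your PyTorch version, so no output is shown here.

```python
import torch

def graph_nodes(fn, depth=0, out=None):
    # Depth-first walk of the autograd graph, collecting (depth, node name) pairs.
    if out is None:
        out = []
    if fn is None:  # inputs that do not require grad appear as None
        return out
    out.append((depth, type(fn).__name__))
    for next_fn, _ in fn.next_functions:
        graph_nodes(next_fn, depth + 1, out)
    return out

x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
a = torch.einsum('ik,kj->ij', x, y)

# Indent each backward node by its distance from the output.
for depth, name in graph_nodes(a.grad_fn):
    print("  " * depth + name)
```

The leaves x and y show up as AccumulateGrad nodes at the bottom of the tree.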
Just as when you write your own computation in Python, PyTorch keeps track of the calls to “explicitly differentiable functions” and then computes the backward pass piece by piece.
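A quick way to convince yourself that the piece-by-piece backward gives the usual result: for 'ik,kj->ij' (a plain matrix product a = x @ y) and an upstream gradient g, the chain rule gives dL/dx = g @ yᵀ and dL/dy = xᵀ @ g. The sketch below compares those hand-derived formulas with what autograd returns; the random upstream gradient g is just an arbitrary choice for the check.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
a = torch.einsum('ik,kj->ij', x, y)

g = torch.randn_like(a)  # an arbitrary upstream gradient dL/da
gx, gy = torch.autograd.grad(a, (x, y), grad_outputs=g)

# Hand-derived chain rule for a = x @ y
ok_x = torch.allclose(gx, g @ y.T, atol=1e-5)
ok_y = torch.allclose(gy, x.T @ g, atol=1e-5)
print(ok_x, ok_y)
```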
The reduction actually goes the other way round compared with NumPy, where bmm reduces to einsum.
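For reference, batch matrix multiplication is itself expressible as an einsum: sum over the shared index k independently for each batch index b. The NumPy sketch below only checks that the two spellings agree numerically; it says nothing about how NumPy dispatches either call internally.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4, 5))
B = rng.standard_normal((3, 5, 2))

# Batch matrix multiplication written as an einsum over the shared index k
C_einsum = np.einsum('bik,bkj->bij', A, B)
C_matmul = A @ B  # np.matmul treats the leading dimension as a batch dimension

print(np.allclose(C_einsum, C_matmul))
```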
If you look at the PyTorch einsum implementation, you see that (save for some optimizations which, IMHO, would have been better placed in sumproduct_pair) it rearranges the operands and then calls sumproduct_pair. That function, in turn, could have a very simple explicit derivative using expand + sumproduct_pair.
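To illustrate the "expand + sumproduct_pair" idea, here is a minimal sketch in Python. sumproduct_pair in PyTorch is an internal C++ helper and is not reproduced here; the sumproduct_pair below is a hypothetical stand-in with assumed semantics (multiply two broadcast-compatible tensors of equal rank, sum over sum_dims keeping them as size 1), and SumProductPair and broadcast_dims are likewise illustrative names. The backward expands the incoming gradient back over the summed dims and then applies the same sum-product with the other operand.

```python
import torch

def sumproduct_pair(left, right, sum_dims):
    # HYPOTHETICAL stand-in, not PyTorch's internal helper: multiply two
    # broadcast-compatible tensors of equal rank and sum over sum_dims,
    # keeping those dims with size 1.
    prod = left * right
    return prod.sum(dim=sum_dims, keepdim=True) if sum_dims else prod


def broadcast_dims(shape, full_shape):
    # Dims along which a tensor of `shape` was broadcast to reach `full_shape`.
    return [d for d, (s, f) in enumerate(zip(shape, full_shape)) if s == 1 and f != 1]


class SumProductPair(torch.autograd.Function):
    @staticmethod
    def forward(ctx, left, right, sum_dims):
        ctx.save_for_backward(left, right)
        return sumproduct_pair(left, right, list(sum_dims))

    @staticmethod
    def backward(ctx, grad_out):
        left, right = ctx.saved_tensors
        full_shape = torch.broadcast_shapes(left.shape, right.shape)
        # expand: broadcast the incoming gradient back over the summed dims
        g = grad_out.expand(full_shape)
        # sumproduct_pair: multiply by the other operand, then sum over the
        # dims along which this operand had been broadcast
        grad_left = sumproduct_pair(g, right, broadcast_dims(left.shape, full_shape))
        grad_right = sumproduct_pair(g, left, broadcast_dims(right.shape, full_shape))
        return grad_left, grad_right, None  # no gradient for sum_dims


x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
y = torch.randn(4, 2, dtype=torch.double, requires_grad=True)

# 'ik,kj->ij' as a sum-product: (i, k, 1) * (1, k, j), summed over k (dim 1)
out = SumProductPair.apply(x.unsqueeze(2), y.unsqueeze(0), [1]).squeeze(1)
print(torch.allclose(out, x @ y))

# Compare the hand-written backward against finite differences
ok = torch.autograd.gradcheck(
    lambda p, q: SumProductPair.apply(p.unsqueeze(2), q.unsqueeze(0), [1]), (x, y)
)
print(ok)
```

Double precision is used because gradcheck's finite-difference comparison is unreliable in float32.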