Einstein summation in PyTorch

I need to implement a module that is essentially a 2D grid of independent torch.nn.Linear instances. Is there an elegant PyTorch way to write the following code:

import numpy as np

# x has shape (k, i, j); w has shape (k, m, i, j)
out = np.zeros(w.shape[1:])
for c in range(out.shape[0]):
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            out[c, i, j] = x[:, i, j].dot(w[:, c, i, j])

Also known as Einstein summation in numpy:

out = np.einsum('kij,kmij->mij', x, w)
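For what it's worth, recent PyTorch versions also ship torch.einsum with the same subscript syntax, so the numpy one-liner ports over directly. A minimal sketch (the sizes k, m, i, j and the random tensors are hypothetical, just for illustration):

```python
import torch

k, m, i, j = 3, 4, 5, 6          # hypothetical grid/feature sizes
x = torch.randn(k, i, j)          # input: k features per grid cell
w = torch.randn(k, m, i, j)       # independent (k, m) weight per cell

# sum over k: out[m, i, j] = sum_k x[k, i, j] * w[k, m, i, j]
out = torch.einsum('kij,kmij->mij', x, w)

assert out.shape == (m, i, j)
```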

An obvious solution here is to just reshape everything and use one huge Linear layer, which is probably good enough for starters.

You could use torch.bmm or torch.baddbmm here, I think, which do batched matrix multiplication:

  • view your 2D grid as a 1D list of matrix multiplies to do
  • do a batch matrix multiply
  • unview back to the 2D grid
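The three steps above can be sketched roughly like this (a minimal sketch, assuming x has shape (k, i, j) and w has shape (k, m, i, j) as in the question; the concrete sizes are made up for illustration):

```python
import torch

k, m, i, j = 3, 4, 5, 6
x = torch.randn(k, i, j)
w = torch.randn(k, m, i, j)

# 1. view the 2D grid as a 1D batch of matrix multiplies:
#    each grid cell contributes a (1, k) row vector and a (k, m) weight matrix
xb = x.permute(1, 2, 0).reshape(i * j, 1, k)     # (i*j, 1, k)
wb = w.permute(2, 3, 0, 1).reshape(i * j, k, m)  # (i*j, k, m)

# 2. batch matrix multiply
out = torch.bmm(xb, wb)                          # (i*j, 1, m)

# 3. unview back to the 2D grid
out = out.reshape(i, j, m).permute(2, 0, 1)      # (m, i, j)

# agrees with the einsum formulation
ref = torch.einsum('kij,kmij->mij', x, w)
assert torch.allclose(out, ref, atol=1e-6)
```

The permute-then-reshape moves the grid axes (i, j) to the front so they become the batch dimension; torch.reshape copies when the permuted tensor is non-contiguous, so no explicit .contiguous() call is needed.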

That’s even better, thanks. I overlooked it for some reason. That should do the trick.