Trouble flexibly implementing a tensor manipulation

Hello,

I’m trying to figure out how to do the following operation in general (and whether it has a name). I am given a list of tensors [X_1, X_2, ..., X_k], each of shape n by m. I want to compute:

T_{i_1,...,i_k} = \sum_{j=1}^{n} [X_1]_{j,i_1} [X_2]_{j,i_2} \cdots [X_k]_{j,i_k}, where each index i_t runs from 1 to m.
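Writing j for the summed row index and x^{(t)}_j for the j-th row of X_t (notation introduced here only for illustration), the operation can be read as a sum of n outer products, one per row; for k = 2 it reduces to an ordinary matrix product:

```latex
T = \sum_{j=1}^{n} x^{(1)}_j \otimes x^{(2)}_j \otimes \cdots \otimes x^{(k)}_j,
\qquad x^{(t)}_j = \big([X_t]_{j,1}, \dots, [X_t]_{j,m}\big) \in \mathbb{R}^{m}

k = 2:\quad T_{i_1,i_2} = \sum_{j=1}^{n} [X_1]_{j,i_1} [X_2]_{j,i_2} = \big(X_1^{\top} X_2\big)_{i_1,i_2}
```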

I can do this easily for a fixed number of tensors by doing:

torch.einsum('ni,nj,nk,nl -> ijkl', [X,Y,Z,W])

However, this doesn't work when the length of the list isn't known in advance. Any thoughts on how to implement this well?

Since the output is a tensor with m ** k elements, you won't be able to make this work for arbitrary m and k: the memory required grows very quickly even for small values (for instance, about 37 GiB for m = k = 10 if your tensor uses 32-bit floats).
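A quick back-of-the-envelope check of that figure, as a small sketch (the helper name dense_output_bytes is just for illustration, and it assumes 4 bytes per float32 element with no overhead):

```python
def dense_output_bytes(m: int, k: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to store a dense tensor with m ** k elements (hypothetical helper)."""
    return (m ** k) * bytes_per_element


GIB = 1024 ** 3
MIB = 1024 ** 2

# m = k = 10 with float32: 10**10 elements * 4 bytes = 4e10 bytes
print(dense_output_bytes(10, 10) / GIB)  # about 37.25
# m = 2, k = 20 with float32: 2**20 elements * 4 bytes
print(dense_output_bytes(2, 20) / MIB)   # 4.0
```

So the exponent k is what matters: with m = 2 even k = 20 stays small, while m = 10 becomes infeasible long before that.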

That being said, you can analyze your use case to see how large k can get. If k < 26, you can write a function that generates the equation string for your torch.einsum call from the lowercase letters (which, for the moment, are the only ones PyTorch supports): one letter for the summed dimension and one letter per input tensor.
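A minimal sketch of such a generator, assuming lowercase letters only (so at most 25 inputs, with 'z' reserved for the shared n dimension). The names outer_contraction_equation and outer_contraction are hypothetical helpers, not part of PyTorch:

```python
import string

import torch


def outer_contraction_equation(k: int) -> str:
    """Build an einsum equation such as 'za,zb,zc->abc' for k operands.

    'z' labels the shared (summed) n dimension and 'a', 'b', ... label the
    k free dimensions, so with lowercase letters only, k can be at most 25.
    """
    letters = string.ascii_lowercase
    if not 1 <= k <= len(letters) - 1:
        raise ValueError(f"k must be between 1 and {len(letters) - 1}, got {k}")
    summed = letters[-1]  # 'z' is reserved for the shared n dimension
    free = letters[:k]    # one output letter per input tensor
    inputs = ",".join(summed + c for c in free)
    return f"{inputs}->{free}"


def outer_contraction(tensors):
    """T[i_1, ..., i_k] = sum_j X_1[j, i_1] * ... * X_k[j, i_k] for a list of (n, m) tensors."""
    return torch.einsum(outer_contraction_equation(len(tensors)), *tensors)


# Usage: four 5 x 2 tensors give a 2 x 2 x 2 x 2 result, and autograd still works.
xs = [torch.randn(5, 2, requires_grad=True) for _ in range(4)]
T = outer_contraction(xs)
print(outer_contraction_equation(4))  # za,zb,zc,zd->abcd
print(T.shape)                        # torch.Size([2, 2, 2, 2])
T.sum().backward()                    # gradients flow back to every tensor in xs
```

Because the result is still a plain torch.einsum call, autograd and the optimizers keep working; the only cost is building the string.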

A side note: unlike torch.einsum, numpy.einsum also accepts lists of integer indices (the "sublist" format), which makes it easy to handle Einstein summation over any number of tensors. Here is an illustration:

import numpy as np

# Four random 2 x 3 arrays (n = 2 rows, m = 3 columns)
x = np.random.rand(2, 3)
y = np.random.rand(2, 3)
z = np.random.rand(2, 3)
w = np.random.rand(2, 3)
# Sublist form: integer labels play the role of letters; label 0 is the shared (summed) axis
r1 = np.einsum(x, [0, 1], y, [0, 2], z, [0, 3], w, [0, 4], [1, 2, 3, 4])  # Not currently supported by PyTorch
r2 = np.einsum('ni,nj,nk,nl -> ijkl', x, y, z, w)  # Supported by PyTorch
print(np.allclose(r1, r2))  # Prints True
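Since the sublists are ordinary Python lists, they can be built in a loop, so the list of arrays can have any length. A sketch (outer_contraction_np is a hypothetical name; NumPy restricts the integer labels to a fixed range, 52 labels as far as I know, so this covers up to 51 arrays):

```python
import numpy as np


def outer_contraction_np(arrays):
    """T[i_1, ..., i_k] = sum_j A_1[j, i_1] * ... * A_k[j, i_k] for any number of (n, m) arrays.

    Axis 0 of every operand gets label 0 (summed away) and axis 1 of the
    t-th operand gets label t + 1; labels are assumed limited to range(52).
    """
    args = []
    for t, a in enumerate(arrays):
        args.extend([a, [0, t + 1]])  # operand followed by its sublist
    output_labels = list(range(1, len(arrays) + 1))
    return np.einsum(*args, output_labels)


arrays = [np.random.rand(2, 3) for _ in range(5)]
print(outer_contraction_np(arrays).shape)  # (3, 3, 3, 3, 3)
```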

Thanks! A common m will be 2, so going up to k = 20 or so is not implausible. I am also aware that I could do this easily in NumPy, but then I'd lose all the niceties (automatic differentiation, pre-made optimizers, etc.) that make torch friendly.

I was indeed leaning this way (at the end of the day, it is still only 26 lines of code), but I felt there had to be a better way to do this!

I opened an issue on PyTorch's GitHub about this. It looks like a PR adding support for the sublist argument format (as in NumPy) would be welcome.
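For anyone reading this later: as far as I know, more recent PyTorch releases (around 1.10 onward, if I recall correctly) document a sublist format for torch.einsum, torch.einsum(op1, sublist1, op2, sublist2, ..., [sublist_out]). Assuming such a version, a sketch of the list-of-any-length case (outer_contraction_sublist is a hypothetical name):

```python
import torch


def outer_contraction_sublist(tensors):
    """Same contraction as the letter-based version, using torch.einsum's sublist format.

    Label 0 is the shared n axis (summed away); label t + 1 is the free axis of tensor t.
    Assumes a PyTorch version whose torch.einsum accepts sublists.
    """
    args = []
    for t, x in enumerate(tensors):
        args.extend([x, [0, t + 1]])  # operand followed by its sublist
    output_labels = list(range(1, len(tensors) + 1))
    return torch.einsum(*args, output_labels)


xs = [torch.randn(5, 2) for _ in range(4)]
print(outer_contraction_sublist(xs).shape)  # torch.Size([2, 2, 2, 2])
```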


Thanks! I posted there with some additional context!
