Since the output is a tensor with m ** k elements, you won't be able to make this work for arbitrary values of m and k: you'll allocate a lot of memory even for small values of m and k (for instance, about 37 GiB for m = k = 10 if your tensor uses 32-bit floats).
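As a quick sanity check of that figure (assuming 4 bytes per element for 32-bit floats):

m = k = 10
n_bytes = (m ** k) * 4      # number of elements times 4 bytes per float32
print(n_bytes / 2 ** 30)    # ~37.25, i.e. roughly 37 GiB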
That being said, you can analyze your use case and see how big k can get. Let's say k < 26: then you can write a function that generates the subscript argument for your torch.einsum call from the lowercase letters (which happen to be the only ones supported by PyTorch at the moment).
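Here is a minimal sketch of such a generator; the function name einsum_outer and the choice of 'z' for the shared axis are just illustrative assumptions, not anything PyTorch prescribes:

import string
import torch

def einsum_outer(tensors):
    # Hypothetical helper: contracts the shared first axis of k 2-D tensors
    # of shape (n, m) and returns a tensor of shape (m,) * k.
    k = len(tensors)
    assert k < 26, "only the 26 lowercase letters are available"
    shared = 'z'                                    # index for the shared axis
    letters = string.ascii_lowercase[:k]            # 'a', 'b', 'c', ...
    inputs = ','.join(shared + l for l in letters)  # 'za,zb,zc,...'
    output = ''.join(letters)                       # 'abc...'
    return torch.einsum(f'{inputs}->{output}', *tensors)

out = einsum_outer([torch.rand(2, 3) for _ in range(4)])
print(out.shape)  # torch.Size([3, 3, 3, 3])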
A side note on torch.einsum compared to numpy.einsum: NumPy also accepts lists of indices as subscripts, which makes it easy to handle Einstein summation over any number of tensors. Here is an illustration:
import numpy as np
x = np.random.rand(6).reshape(2, 3)
y = np.random.rand(6).reshape(2, 3)
z = np.random.rand(6).reshape(2, 3)
w = np.random.rand(6).reshape(2, 3)
r1 = np.einsum(x, [0, 1], y, [0, 2], z, [0, 3], w, [0, 4], [1, 2, 3, 4]) # Not currently supported by PyTorch
r2 = np.einsum('ni,nj,nk,nl -> ijkl', x, y, z, w) # Supported by PyTorch
print(np.allclose(r1, r2)) # Prints True
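For completeness, a minimal sketch of the r2 form written with PyTorch (the list-of-indices form used for r1 has no direct equivalent there):

import torch
tx, ty, tz, tw = (torch.from_numpy(a) for a in (x, y, z, w))
r3 = torch.einsum('ni,nj,nk,nl->ijkl', tx, ty, tz, tw)
print(np.allclose(r3.numpy(), r2)) # Prints True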