# Alternative to einsum for 3D tensor product?

Hi,

I am doing temporal resampling of multivariate time series using a tensor product.

The goal is to transform an input signal with C channels and Ti time stamps into an output signal with the same number of channels and To time stamps. Each channel is interpolated independently, and each output sample is a simple linear combination of the input samples.

That is, sample i of output channel c is
$y_c[i] = \sum_{l} a_c[i,l]\, x_c[l]$
where x_c is the corresponding input channel.

This is of course just a matrix multiplication,
$y = x A$,
where x and y are (row) vectors and A is the matrix of a_c[i,l] coefficients, transposed so that it has shape (in_dates, out_dates).
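
For a single channel, here is a minimal sketch of this matrix-multiplication view (the sizes are made up for illustration):

import torch

Ti, To = 3, 5              # made-up numbers of input / output time stamps
x_c = torch.randn(Ti)      # one input channel
A = torch.randn(Ti, To)    # coefficients, shape (in_dates, out_dates)
y_c = x_c @ A              # shape (To,): each output sample is a
                           # linear combination of the input samples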

To do this with minibatches, I need a tensor product instead of a matrix
product, since x will be a tensor with shape (batch, channels, in_dates)
and A will have shape (batch, in_dates, out_dates), and I want the output
to have shape (batch, channels, out_dates).

I have implemented this as follows (b is for Batch, c for Channels, i for Input time stamps and o for Output time stamps):
y = torch.einsum('bci,bio->bco', x, A)

The problem I am facing is that this is very slow. I guess that building the operation from a string does not allow any optimization and I was wondering if there is a way to implement this using other faster operations. Maybe there is some reshaping, (un)squeezing and broadcasting black magic, but I can’t figure it out!

Hi Jordi!

First let me note that in my (limited) experience einsum() is not
particularly slow. When I’ve timed it on some toy problems it’s
been more or less as fast as alternatives based on reshaping the
tensors being multiplied (and sometimes faster). Having said that,
I do recall seeing some GitHub issues suggesting that einsum()
might be better at optimizing some index contractions than others.
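
If you want to check this on your actual shapes, torch.utils.benchmark
makes the comparison straightforward. Here is a minimal sketch (the
sizes are placeholders, and the matmul() version is the one spelled
out below):

import torch
from torch.utils import benchmark

x = torch.randn (32, 8, 100)    # placeholder (batch, channels, in_dates)
A = torch.randn (32, 100, 50)   # placeholder (batch, in_dates, out_dates)

t_einsum = benchmark.Timer (
    stmt = "torch.einsum ('bci, bio -> bco', x, A)",
    globals = {'torch': torch, 'x': x, 'A': A},
)
t_matmul = benchmark.Timer (
    stmt = "torch.matmul (x.unsqueeze (2), A.unsqueeze (1)).squeeze (2)",
    globals = {'torch': torch, 'x': x, 'A': A},
)

print (t_einsum.timeit (100))   # each prints a Measurement with the mean time
print (t_matmul.timeit (100))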

Yes, you can do this with a couple of unsqueeze()s to line things up
correctly and then use matmul() with its “batch-dimensions” broadcasting
to broadcast A across the channel dimension it lacks:

>>> import torch
>>> torch.__version__
'1.7.1'
>>> _ = torch.manual_seed (2021)
>>> nBatch = 2
>>> nChannel = 2
>>> nInput = 3
>>> nOutput = 5
>>> x = torch.randn ((nBatch, nChannel, nInput))
>>> A = torch.randn ((nBatch, nInput, nOutput))
>>> torch.matmul (x.unsqueeze (dim = 2), A.unsqueeze (dim = 1)).squeeze()
tensor([[[ 3.2144,  0.5790,  0.3564, -0.5355,  1.5428],
         [-2.4341, -1.2235, -0.5440,  0.9202, -0.3586]],

        [[-0.3876, -3.7793,  1.8061,  3.0808,  1.4999],
         [-0.6005, -0.9385,  5.1433,  0.8513,  0.3969]]])
>>> torch.einsum ('bci, bio -> bco', x, A)
tensor([[[ 3.2144,  0.5790,  0.3564, -0.5355,  1.5428],
         [-2.4341, -1.2235, -0.5440,  0.9202, -0.3586]],

        [[-0.3876, -3.7793,  1.8061,  3.0808,  1.4999],
         [-0.6005, -0.9385,  5.1433,  0.8513,  0.3969]]])
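
Two side notes. With no argument, squeeze() removes every size-1
dimension, so if your batch or channel size happened to be 1 it would
get dropped as well; squeeze (dim = 2) only removes the dimension that
the unsqueeze() added. Also, if I’m reading your einsum string right,
the channel dimension of x already sits in the matrix position, so
plain matmul() (or bmm()) should give the same contraction with no
unsqueezing at all:

y2 = torch.matmul (x, A)   # (b, c, i) @ (b, i, o) -> (b, c, o)
# should match: torch.allclose (y2, torch.einsum ('bci, bio -> bco', x, A))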


Best.

K. Frank
