Alternative to einsum for 3D tensor product?


I am doing temporal resampling of multivariate time series using a tensor product.

The goal is to transform an input signal which has C channels and Ti time stamps to an output signal with the same number of channels and To time stamps. Each channel is interpolated independently and the output time stamps are a simple linear combination of the input ones.

That is, time stamp i of output channel c is
$y_c[i] = \sum_{l} a_c[i,l] x_c[l]$
where x_c is the corresponding input channel.

This is of course just a matrix multiplication
$y = x A$
where A is the matrix containing the a_c[i,l] coefficients and y and x are vectors.
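As a sanity check, here is a toy sketch with made-up sizes (laying out A with rows indexed by input time, so that A[l, i] holds the coefficient a[i, l] and y = x A works out):

```python
import torch

torch.manual_seed(0)
Ti, To = 6, 4
x = torch.randn(Ti)        # one input channel, Ti time stamps
A = torch.randn(Ti, To)    # A[l, i] holds the coefficient a[i, l]

# Explicit sum y[i] = sum_l a[i, l] * x[l] ...
y_loop = torch.stack([sum(A[l, i] * x[l] for l in range(Ti)) for i in range(To)])
# ... equals the vector-matrix product y = x A.
y_mat = x @ A

print(torch.allclose(y_loop, y_mat))  # True
```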

To do this with minibatches, I need a tensor product instead of a matrix
product, since x will be a tensor with shape (batch,channels,in_dates)
and A will have shape (batch,in_dates,out_dates) and I want the output
to have shape (batch, channels, out_dates).

I have implemented this as follows (b is for Batch, c for Channels, i for Input time stamps and o for Output time stamps):
y = torch.einsum('bci,bio->bco', x, A)

The problem I am facing is that this is very slow. I guess that building the operation from a string does not allow any optimization and I was wondering if there is a way to implement this using other faster operations. Maybe there is some reshaping, (un)squeezing and broadcasting black magic, but I can’t figure it out!

Thank you for your help.

Hi Jordi!

First let me note that in my (limited) experience einsum() is not
particularly slow. When I’ve timed it on some toy problems it’s
been more-or-less as fast as alternatives based on reshaping the
tensors being multiplied (and sometimes faster). Having said that,
I do recall seeing some github issues that suggest that einsum()
might be better at optimizing some index contractions than others.

Yes, you can do this with a couple of unsqueeze()s to line things up
right and then use matmul() with its “batch-dimensions” broadcasting
to broadcast the absent channel dimension of A:

>>> import torch
>>> torch.__version__
>>> _ = torch.manual_seed (2021)
>>> nBatch = 2
>>> nChannel = 2
>>> nInput = 3
>>> nOutput = 5
>>> x = torch.randn ((nBatch, nChannel, nInput))
>>> A = torch.randn ((nBatch, nInput, nOutput))
>>> torch.matmul (x.unsqueeze (dim = 2), A.unsqueeze (dim = 1)).squeeze()
tensor([[[ 3.2144,  0.5790,  0.3564, -0.5355,  1.5428],
         [-2.4341, -1.2235, -0.5440,  0.9202, -0.3586]],

        [[-0.3876, -3.7793,  1.8061,  3.0808,  1.4999],
         [-0.6005, -0.9385,  5.1433,  0.8513,  0.3969]]])
>>> torch.einsum ('bci, bio -> bco', x, A)
tensor([[[ 3.2144,  0.5790,  0.3564, -0.5355,  1.5428],
         [-2.4341, -1.2235, -0.5440,  0.9202, -0.3586]],

        [[-0.3876, -3.7793,  1.8061,  3.0808,  1.4999],
         [-0.6005, -0.9385,  5.1433,  0.8513,  0.3969]]])
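(As an aside, a sketch of the same idea without the unsqueeze()s: since
'bci,bio->bco' contracts over a single shared batch dimension, bmm() - or
matmul() applied directly to the 3-d tensors - should give the same result,
and may be worth timing as well:)

```python
import torch

torch.manual_seed(2021)
x = torch.randn(2, 2, 3)   # (batch, channels, in_dates)
A = torch.randn(2, 3, 5)   # (batch, in_dates, out_dates)

y_ein = torch.einsum('bci,bio->bco', x, A)
y_bmm = torch.bmm(x, A)        # batched matrix product, no unsqueeze() needed
y_mm  = torch.matmul(x, A)     # matmul() treats leading dims as batch dims

print(torch.allclose(y_ein, y_bmm), torch.allclose(y_ein, y_mm))  # True True
```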


K. Frank

Thanks for your help!

Yes, I did some research before posting and saw these issues and some posts on this forum too. Maybe I could reorder the dimensions and time different options.

Your answer is the kind of approach I was looking for. Actually, I already have some reshaping earlier in the code because channels and time are originally interleaved in the data source, so I wonder if chaining all these operations might impact performance.

I will time it for my particular case and see if this speeds things up.
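For the record, a minimal timing harness might look like this (sizes are made up; substitute the real ones):

```python
import timeit

import torch

b, c, ti, to = 32, 16, 256, 128   # made-up sizes; substitute the real ones
x = torch.randn(b, c, ti)
A = torch.randn(b, ti, to)

candidates = [
    ('einsum', lambda: torch.einsum('bci,bio->bco', x, A)),
    # squeeze(2) (not a bare squeeze()) so batch dims of size 1 survive
    ('matmul', lambda: torch.matmul(x.unsqueeze(2), A.unsqueeze(1)).squeeze(2)),
]
for name, fn in candidates:
    t = timeit.timeit(fn, number=100)
    print(f'{name}: {t:.4f} s for 100 runs')
```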

Thanks again!