Alternative to einsum

Hi all,

I am trying to build a capsule neural network (CapsNet) model from scratch to understand how it works.

With this line of code, it works:

def input_caps2U(self, x):
    return torch.einsum('bij,ijkl->bikl', x, self.weight).contiguous()

but I would like to replace einsum. Does anyone have an idea how to do this? Please show me, because no matter what I change, I always get an error.

Thank you for your help.

Hi Py!

With appropriate permute()s and unsqueeze()s to get the dimensions
aligned properly, you can use (batch) matmul().

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> b = 3
>>> i = 4
>>> j = 5
>>> k = 2
>>> l = 2
>>>
>>> x = torch.randn (b, i, j)
>>> weight = torch.randn (i, j, k, l)
>>>
>>> resultA = torch.einsum ('bij, ijkl -> bikl', x, weight)
>>> xB = x.unsqueeze (2)
>>> weightB = weight.permute (2, 3, 0, 1).unsqueeze (-1).unsqueeze (2)
>>> resultB = (xB @ weightB).squeeze().permute (2, 3, 0, 1)
>>>
>>> torch.equal (resultA, resultB)
True
>>>
>>> resultA.shape
torch.Size([3, 4, 2, 2])

As an aside, einsum() works well. Why would you want to forgo the
convenience it offers?

Best.

K. Frank


Hi Frank,

First, thanks a lot for your help.

I would like to replace it to understand a little more about how einsum works, to compare the performance of the two approaches, and to see whether there are other ways to optimize memory or make the calculation faster. By the way, I wanted to ask you three questions:

1- Could you please explain the code you have written? I would like to understand how you arrived at the lines you suggested, which work very well, because I had been trying to find an alternative to einsum for a week.

2- Is there a way of comparing their performance (memory, speed, etc.)?

3- How can I use matmul() or matrix multiplication in plain Python, just to learn how einsum works?

Thank you for your help.

Hi Py!

Print out the shapes of the relevant tensors step by step after each
permute(), unsqueeze(), and squeeze() to see how the dimensions
move around to get aligned. Look at the documentation for matmul().
The key point is that when both tensors have more than two dimensions,
“batch” matrix multiplication is performed: the matrix multiplication is
applied to the last two dimensions, while the leading dimensions are
treated as batch dimensions and must be broadcastable. weight’s
unsqueeze (2) adds a singleton dimension that gets broadcast against
x’s b dimension.

Memory-profiling pytorch tends to be difficult, but I wouldn’t expect the
two approaches to differ significantly in memory. (However, broadcasting
can sometimes use memory unexpectedly and seemingly unnecessarily.)
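
If you do want a rough number, one sketch is to compare peak memory with
pytorch’s CUDA memory counters (this assumes a CUDA device is available,
and the shapes are made up for illustration):

import torch

dev = torch.device ('cuda')
x = torch.randn (64, 32, 32, device = dev)
weight = torch.randn (32, 32, 16, 16, device = dev)

def with_einsum():
    return torch.einsum ('bij, ijkl -> bikl', x, weight)

def with_matmul():
    xB = x.unsqueeze (2)
    weightB = weight.permute (2, 3, 0, 1).unsqueeze (-1).unsqueeze (2)
    return (xB @ weightB).squeeze().permute (2, 3, 0, 1)

for name, fn in [('einsum', with_einsum), ('matmul', with_matmul)]:
    torch.cuda.reset_peak_memory_stats (dev)   # zero the peak-memory counter
    result = fn()
    torch.cuda.synchronize (dev)               # wait for the kernel to finish
    print (name, 'peak bytes:', torch.cuda.max_memory_allocated (dev))
    del result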

You can perform timing tests using python’s time.time(). In earlier versions
of pytorch we have sometimes seen einsum() perform unexpectedly and
unnecessarily poorly in comparison to matmul() (and sometimes even in
comparison to loops), and sometimes seen it outperform matmul()
(presumably due to some performance bug in matmul()).
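
Here is a minimal sketch of such a timing test (the shapes are arbitrary;
on a GPU you would also need torch.cuda.synchronize() before reading the
clock):

import time
import torch

x = torch.randn (64, 32, 32)
weight = torch.randn (32, 32, 16, 16)

def with_einsum():
    return torch.einsum ('bij, ijkl -> bikl', x, weight)

def with_matmul():
    xB = x.unsqueeze (2)
    weightB = weight.permute (2, 3, 0, 1).unsqueeze (-1).unsqueeze (2)
    return (xB @ weightB).squeeze().permute (2, 3, 0, 1)

for name, fn in [('einsum', with_einsum), ('matmul', with_matmul)]:
    fn()   # warm-up call, excluded from the timing
    t0 = time.time()
    for _ in range (100):
        fn()
    print (name, (time.time() - t0) / 100, 'seconds per call')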

Best.

K. Frank


Hi Frank,

Thanks a lot for the explanation, now I have a better understanding.

Best.

The point of einsum is not necessarily speed. In theory, the two approaches should perform about the same in most circumstances. (Granted, there may be cases where one runs faster or slower.)

The reason some people prefer einsum, whether it be in Tensorflow, PyTorch or NumPy, is that it provides a more concise and natural way to write and understand the tensor operations.

You can think of einsum as a type of code shorthand for tensor operations that is universal between all machine learning libraries. Once you learn it, the tensor ops become easier to visualize, check, and troubleshoot, regardless of which library you are working in.
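
As a small illustration of that universality, the same subscript string expresses the same contraction in both NumPy and PyTorch (a sketch with made-up shapes):

import numpy as np
import torch

a = np.random.randn(3, 4, 5)
w = np.random.randn(4, 5, 2, 2)

# Identical subscript string in both libraries:
out_np = np.einsum('bij,ijkl->bikl', a, w)
out_pt = torch.einsum('bij,ijkl->bikl', torch.from_numpy(a), torch.from_numpy(w))

print(out_np.shape)   # (3, 4, 2, 2)
print(out_pt.shape)   # torch.Size([3, 4, 2, 2])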


Hi J_Johnson,

Thanks a lot for these explanations, which helped me learn a little more about einsum.