Tricky einsum operation going wrong

Hi folks,

Consider the tensors A and B, where A has shape torch.Size([4096, 25088, 8]) and B has shape torch.Size([256, 25088]). The final result I want is the tensor C produced by the following einsum operation:

C = torch.einsum('bij,ki->bkj', A, B)
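
For concreteness, here is a small-scale sketch of that operation with dummy shapes standing in for the real ones above; it also shows the equivalent batched matrix product, just as a reference point:

import torch

# dummy stand-ins for the real shapes [4096, 25088, 8] and [256, 25088]
A = torch.randn(4, 10, 3)
B = torch.randn(5, 10)

C = torch.einsum('bij,ki->bkj', A, B)
print(C.shape)              # torch.Size([4, 5, 3])

# equivalent batched matrix product: C[b] = B @ A[b]
C_ref = B @ A
print(torch.allclose(C, C_ref))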

However, to get C, I want to split A and B into sections, combine the sections with einsum, and then accumulate over the sections using torch.sum. Consider A_sec with shape torch.Size([4096, 196, 128, 8]) and B_sec with shape torch.Size([256, 196, 128]); in this case I have 196 sections of 128 vectors each. The same result is obtained by doing

C = torch.einsum('bijk,lij->blk', A_sec, B_sec)
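
Just to make the setup concrete, here is a small-scale sketch with dummy sizes of how I build the sections and why the two einsums above agree. I'm assuming here that the sections simply come from reshaping the shared dimension, which is what I do in my case:

import torch

nb, nshared, nj = 4, 12, 3        # dummy sizes: batch, shared dim, trailing dim
nk = 5                            # dummy stand-in for 256
n_sections, section_size = 3, 4   # nshared == n_sections * section_size

A = torch.randn(nb, nshared, nj)
B = torch.randn(nk, nshared)

# sections: split the shared dimension into (n_sections, section_size)
A_sec = A.reshape(nb, n_sections, section_size, nj)
B_sec = B.reshape(nk, n_sections, section_size)

C_full = torch.einsum('bij,ki->bkj', A, B)
C_sectioned = torch.einsum('bijk,lij->blk', A_sec, B_sec)

print(torch.allclose(C_full, C_sectioned))   # True (up to round-off)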

But, as mentioned before, I want to accumulate over the sections explicitly, something like this:

# THIS PIECE OF CODE IS WRONG!
C_sec = torch.einsum('bijk,lij->bilk', A_sec, B_sec)
C = torch.sum(C_sec, axis=1)

But this doesn’t give the same C as before. What is wrong with what I did? Which einsum operations do I need to perform to get the result I want?

Thanks

Hi Matheus!

Your two versions of einsum() work for me:

>>> import torch
>>> print (torch.__version__)
2.0.1
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> nb = 4096
>>> ni = 196
>>> nj = 128
>>> nk = 8
>>> nl = 256
>>>
>>> A_sec = torch.randn (nb, ni, nj, nk)
>>> B_sec = torch.randn (nl, ni, nj)
>>>
>>> A_sec.shape
torch.Size([4096, 196, 128, 8])
>>> B_sec.shape
torch.Size([256, 196, 128])
>>>
>>> C1 = torch.einsum ('bijk, lij -> blk', A_sec, B_sec)
>>> C_sec = torch.einsum ('bijk, lij -> bilk', A_sec, B_sec)
>>> C2 = torch.sum (C_sec, axis = 1)
>>>
>>> torch.allclose (C1, C2, atol = 1.e-3)
True

Best.

K. Frank


Hi Frank!

This is very interesting. The way I was comparing C1 and C2 in your code was by writing

(C1 == C2).all()

In fact, when I do this in your code, it prints False, even though torch.allclose() passes with atol = 1.e-3. When I tighten it to atol = 1.e-4, allclose() also starts returning False. In my previous examples, C1 and C2 were equal according to (C1 == C2).all(). Do you know why this is happening?

Hi Matheus!

This tests for exact equality, which you can’t, in general, expect in the presence of floating-point round-off error. (The two versions perform their summations in different orders, so the results can differ by round-off error.)
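
Here is a tiny illustration of the effect, independent of einsum(). The exact outcome is machine-dependent, but exact equality typically fails even though the difference is pure round-off:

import torch

_ = torch.manual_seed (2023)
x = torch.randn (100000)

s1 = x.sum()                                    # sum in one pass
s2 = x.reshape (1000, 100).sum (dim = 1).sum()  # sum chunk-wise, then combine

print (bool (s1 == s2))     # typically False -- different summation order
print ((s1 - s2).abs())     # tiny -- the difference is only round-off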

The tensors being multiplied are quite large, so individual elements of the result
range up to about 1.e3 in magnitude. An absolute tolerance (atol) of 1.e-3
then corresponds to about six decimal digits of agreement between the two results,
which is consistent with the six-to-seven digits you expect from single-precision
(float) arithmetic.
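
If you prefer a check that doesn't depend on the magnitude of the entries, you could look at the relative discrepancy directly (continuing from the session in my earlier reply; the exact number will of course vary):

rel_err = (C1 - C2).abs().max() / C1.abs().max()
print (rel_err)   # roughly 1.e-7 to 1.e-6 here, i.e. single-precision accuracy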

Best.

K. Frank


Thanks a lot for the explanation! Really appreciate it :slight_smile: