Replace a "for" loop with a PyTorch tensor multiplication

Hi, I’m not good at matrix multiplication tricks.

I want to get rid of the for loops in my code by using PyTorch functions, but the formula is complicated and I can’t find a way in. Can the for loops below be replaced with a PyTorch operation?

import torch
B=10
L=20
H=5
mat_A=torch.randn(B,L,L,H)
mat_B=torch.randn(L,B,B,H)
tmp_B=torch.zeros_like(mat_B)
for x in range(L):
   for y in range(B):
       for z in range(B):
           tmp_B[:,y,z,:]+=mat_B[x,y,z,:]*mat_A[z,x,:,:]

I’d like to know if there are any tips on how to solve this kind of problem.

Hi Hyeonuk!

Yes, my tip is to use PyTorch’s “Swiss-army knife” of tensor multiplication,
torch.einsum().

>>> import torch
>>> torch.__version__
'1.11.0'
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> B=10
>>> L=20
>>> H=5
>>>
>>> mat_A=torch.randn(B,L,L,H)
>>> mat_B=torch.randn(L,B,B,H)
>>> tmp_B=torch.zeros_like(mat_B)
>>> for x in range(L):
...    for y in range(B):
...        for z in range(B):
...            tmp_B[:,y,z,:]+=mat_B[x,y,z,:]*mat_A[z,x,:,:]
...
>>> C = torch.einsum ('xyzh,zxlh -> lyzh', mat_B, mat_A)
>>>
>>> torch.allclose (C, tmp_B, atol = 1.e-5)
True
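
As a general tip for this kind of problem, one way to arrive at the subscript string is to give every loop variable and every slice its own letter. In the loop, mat_B is indexed as [x, y, z, h], mat_A as [z, x, l, h], and the result as [l, y, z, h]. The letter x appears only on the input side, so it is the index that gets summed. The snippet below is a minimal sketch of that reading; the names A_perm and C_bcast are illustrative and not from the original code. It also spells out the same contraction with plain broadcasting as a cross-check:

```python
import torch

torch.manual_seed(2022)

B, L, H = 10, 20, 5
mat_A = torch.randn(B, L, L, H)  # read as [z, x, l, h]
mat_B = torch.randn(L, B, B, H)  # read as [x, y, z, h]

# x is missing from the output subscripts, so einsum sums over it:
# C[l, y, z, h] = sum_x mat_B[x, y, z, h] * mat_A[z, x, l, h]
C = torch.einsum('xyzh,zxlh->lyzh', mat_B, mat_A)

# The same contraction written with permute + broadcasting.
A_perm = mat_A.permute(1, 2, 0, 3)          # [x, l, z, h]
prod = mat_B[:, None] * A_perm[:, :, None]  # [x, 1, y, z, h] * [x, l, 1, z, h]
C_bcast = prod.sum(dim=0)                   # sum over x -> [l, y, z, h]

print(C.shape, torch.allclose(C, C_bcast, atol=1e-4))
```

The broadcasting version materializes an intermediate of shape (L, L, B, B, H) before summing, which is one reason einsum is usually the more convenient way to write it.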

Best.

K. Frank

Amazing, I need to practice using einsum! Thanks.