Multiplication of 4D and 3D tensor using torch.einsum

How do I multiply a 4D tensor with a 3D tensor, as in the code below?
qr = torch.einsum('b h i d, i d j -> b h d j', q, r_q)
b = 64
h = 8
i = 8
d = 64
j = 64

Hi @Jayden9912,

When you say “multiply” do you mean matrix-multiplication? Or do you want to use element-wise multiplication instead?

Your code example works and returns a Tensor of shape [b,h,d,j]:

import torch

b = 64
h = 8
i = 8
d = 64
j = 64

q = torch.randn(b, h, i, d)
r_q = torch.randn(i, d, j)

qr = torch.einsum("bhid, idj->bhdj", q, r_q)
print(qr.shape)  # prints torch.Size([64, 8, 64, 64])

I assume you want to multiply every tensor in your batch (the 4D tensor) by the same 3D tensor?

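The einsum expression ("bhid, idj -> bhdj") you have sums (contracts) over the index i, because i appears in both inputs but not in the output. The index d appears in both inputs and in the output, so it is not summed: the two d indices are matched element-wise, like a batch dimension. In other words, this takes the diagonal over d rather than the trace. There are some examples of einsum notation in NumPy's documentation for numpy.einsum.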
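
To make the role of each index concrete, here is a small sketch that compares the einsum against two explicit versions of the same computation: a broadcasted product summed over i, and a batched matmul that treats d as a batch dimension. The sizes are smaller than the ones in the question, and the names ref_sum, ref_matmul, q_rows and r_mats are made up for this illustration, not taken from the thread.

```python
import torch

torch.manual_seed(0)

# Smaller sizes than in the thread, chosen only to keep the check fast.
b, h, i, d, j = 4, 2, 3, 5, 6

q = torch.randn(b, h, i, d)
r_q = torch.randn(i, d, j)

# Expression from the question: i is summed, d is kept in the output.
qr = torch.einsum("bhid,idj->bhdj", q, r_q)

# Reference 1: broadcast both inputs to (b, h, i, d, j), then sum over i (dim 2).
ref_sum = (q.unsqueeze(-1) * r_q).sum(dim=2)

# Reference 2: batched matmul with d acting as a batch dimension.
q_rows = q.permute(0, 1, 3, 2).unsqueeze(-2)           # (b, h, d, 1, i)
r_mats = r_q.permute(1, 0, 2)                          # (d, i, j)
ref_matmul = torch.matmul(q_rows, r_mats).squeeze(-2)  # (b, h, d, j)

print(qr.shape)  # torch.Size([4, 2, 5, 6])
print(torch.allclose(qr, ref_sum, atol=1e-5))
print(torch.allclose(qr, ref_matmul, atol=1e-5))
```

If both comparisons agree, the einsum is a vector-matrix product over i computed separately for each value of d, which is probably the easiest way to read the expression.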