How do I multiply a 4D tensor with a 3D tensor, as in the code below?
qr = torch.einsum('b h i d, i d j -> b h d j', q, r_q)
b = 64
h = 8
i = 8
d = 64
j = 64
Hi @Jayden9912,
When you say “multiply” do you mean matrix-multiplication? Or do you want to use element-wise multiplication instead?
Your code example works and returns a Tensor of shape [b,h,d,j]:
import torch
b=64
h=8
i=8
d=64
j=64
q=torch.randn(b,h,i,d)
r_q=torch.randn(i,d,j)
qr=torch.einsum("bhid, idj->bhdj", q, r_q)
print(qr.shape)  # prints torch.Size([64, 8, 64, 64])
I assume you want to multiply a batch of Tensors (the 4D tensor) with a single 3D tensor for all tensors in your batch?
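The einsum expression you have ("bhid, idj -> bhdj") sums over the index i and keeps d as a shared index, i.e. qr[b,h,d,j] = sum_i q[b,h,i,d] * r_q[i,d,j]. This is equivalent to taking a matmul over the i-th index and then taking the diagonal over the two d indices (the diagonal rather than the trace, since d is not summed away and still appears in the output). There are some examples of einsum notation in numpy's documentation (see here).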
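If it helps to see what the expression does without einsum, here is a minimal sketch that reproduces it in two other ways: a broadcasted multiply followed by a sum over i, and a batched matmul with d treated as an extra batch dimension. The small sizes and the names qr_sum, qr_matmul, q_t and r_t are only for illustration and are not from the original post.

```python
import torch

torch.manual_seed(0)

# Small illustrative sizes (assumption: chosen only to keep the check fast)
b, h, i, d, j = 2, 3, 4, 5, 6

q = torch.randn(b, h, i, d)
r_q = torch.randn(i, d, j)

# Reference result from the einsum in the question
qr = torch.einsum("bhid,idj->bhdj", q, r_q)

# Variant 1: broadcast q to [b, h, i, d, 1], multiply with r_q [i, d, j],
# then sum over the i dimension (dim=2)
qr_sum = (q.unsqueeze(-1) * r_q).sum(dim=2)

# Variant 2: treat d as a batch dimension and contract i with matmul
q_t = q.transpose(2, 3).unsqueeze(-2)           # [b, h, d, 1, i]
r_t = r_q.transpose(0, 1)                       # [d, i, j]
qr_matmul = torch.matmul(q_t, r_t).squeeze(-2)  # [b, h, d, j]

print(qr.shape)  # torch.Size([2, 3, 5, 6])
print(torch.allclose(qr, qr_sum, atol=1e-5))
print(torch.allclose(qr, qr_matmul, atol=1e-5))
```

All three give the same [b,h,d,j] tensor up to floating-point error. The einsum version is usually the easiest to read, and the broadcast-and-sum variant materialises an intermediate of shape [b,h,i,d,j], so it will use more memory at the sizes from the question.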