torch.matmul problem

I have two tensors, torch.randn(3, 10, 4) and torch.randn(4, 10, 3).
I want to multiply them to get a (10, 10) tensor.

However, I get an error when I use matmul. Which part is wrong?
tensor1 = torch.randn(3, 10, 4)
tensor2 = torch.randn(4, 10, 3)
print(torch.matmul(tensor1, tensor2).size())
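A sketch of why the call above fails. For 3-D inputs, torch.matmul() performs a batched matrix product: the last two dimensions are treated as matrices and the leading dimension as a batch that must broadcast. With the shapes above, the batch sizes (3 vs 4) cannot broadcast and the inner matrix sizes (4 vs 10) do not match. The tensors a and b below are illustrative only.

```python
import torch

# Batched product: (b, n, m) @ (b, m, p) -> (b, n, p)
a = torch.randn(3, 10, 4)
b = torch.randn(3, 4, 5)
print(torch.matmul(a, b).shape)  # torch.Size([3, 10, 5])

# The shapes from the question: batch dims 3 vs 4 cannot broadcast, and
# the matrices (10, 4) @ (10, 3) have mismatched inner sizes.
tensor1 = torch.randn(3, 10, 4)
tensor2 = torch.randn(4, 10, 3)
matmul_error = None
try:
    torch.matmul(tensor1, tensor2)
except RuntimeError as err:
    matmul_error = err
    print("matmul failed:", err)
```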

Hi Hcleung!

You may use the Swiss-army knife of generalized tensor-multiplication
operators, torch.einsum(), to “contract” the desired indices:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> tensor1 = torch.randn(3, 10, 4)
>>> tensor2 = torch.randn(4, 10, 3)
>>> torch.einsum ('kil,ljk -> ij', tensor1, tensor2).shape
torch.Size([10, 10])
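To see what 'kil,ljk -> ij' does, note that it sums over both the size-3 index k and the size-4 index l. Assuming that is the product you want, the same result can be obtained with plain matmul() by permuting each tensor so that k and l sit together, then flattening them into one axis of length 12 (in the same k, l order on both sides):

```python
import torch

tensor1 = torch.randn(3, 10, 4)  # indices (k, i, l)
tensor2 = torch.randn(4, 10, 3)  # indices (l, j, k)

via_einsum = torch.einsum('kil,ljk->ij', tensor1, tensor2)

# (k, i, l) -> (i, k, l) -> (i, k*l)
lhs = tensor1.permute(1, 0, 2).reshape(10, 3 * 4)
# (l, j, k) -> (k, l, j) -> (k*l, j)
rhs = tensor2.permute(2, 0, 1).reshape(3 * 4, 10)
via_matmul = torch.matmul(lhs, rhs)

print(via_matmul.shape)  # torch.Size([10, 10])
print(torch.allclose(via_einsum, via_matmul, atol=1e-5))
```

reshape() is used instead of view() because permute() returns a non-contiguous tensor.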

However, you haven’t told us what sort of “multiplication” you want
to perform, so you should check that einsum() is doing what you
want and expect.
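One way to do that check is to compare the einsum() result against the same contraction written out with explicit loops, i.e. result[i, j] = sum over k and l of tensor1[k, i, l] * tensor2[l, j, k]. If the loop version is not the formula you had in mind, the einsum() subscripts need to change too. This is a slow, illustrative reference only:

```python
import torch

tensor1 = torch.randn(3, 10, 4)
tensor2 = torch.randn(4, 10, 3)
result = torch.einsum('kil,ljk->ij', tensor1, tensor2)

# Explicit-loop reference for the contraction over k and l
reference = torch.zeros(10, 10)
for i in range(10):
    for j in range(10):
        for k in range(3):
            for l in range(4):
                reference[i, j] += tensor1[k, i, l] * tensor2[l, j, k]

print(torch.allclose(result, reference, atol=1e-5))
```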

Best.

K. Frank