How to multiply matrices of shapes (batch_size, n1) and (n1, n2, n3)?

I want to multiply a batched matrix by a 3D tensor. This works fine when mat2 is 2D:

import torch

batch_size = 256
mat1 = torch.randn(batch_size, 25)
mat2 = torch.randn(25, 36)
torch.matmul(mat1, mat2).shape  # torch.Size([256, 36])

However, when mat2 is 3D:

batch_size = 256
mat1 = torch.randn(batch_size, 25)
mat2 = torch.randn(25, 36, 36)
torch.matmul(mat1, mat2).shape

it raises the following error: RuntimeError: mat1 and mat2 shapes cannot be multiplied.
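The error comes from how torch.matmul handles a 3D operand: the leading dimension is treated as a batch dimension, so mat2 is read as 25 separate 36x36 matrices, not as a weight tensor indexed by its first dim. The sketch below, which assumes the same shapes as the question, shows the failure and a left operand whose shape would be compatible. The exact wording of the error message may differ between PyTorch versions.

```python
import torch

batch_size = 256
mat1 = torch.randn(batch_size, 25)
mat2 = torch.randn(25, 36, 36)

# torch.matmul treats the first dim of a 3D tensor as a batch dim, so mat2 is
# read as 25 matrices of shape 36x36, and mat1 (256x25) is multiplied with each.
# The inner sizes 25 and 36 do not match, so the call raises.
try:
    torch.matmul(mat1, mat2)
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)

# A left operand whose last dim is 36 is compatible: it is broadcast across the
# 25 batch entries, giving one 256x36 result per 36x36 matrix in mat2.
ok = torch.matmul(torch.randn(batch_size, 36), mat2)
print(ok.shape)  # torch.Size([25, 256, 36])
```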


I suppose you want to multiply mat1 by each 25x36 matrix in mat2 (there are 36 of them, i.e. mat2.shape[-1]) and get 36 results.
The following code may work:

import torch

b = 256
mat1 = torch.randn(b, 25)
mat2 = torch.randn(25, 36, 36)
mat2_p = mat2.permute(2, 0, 1)  # move the last dimension (matrix number) to the first
res = torch.matmul(mat1, mat2_p)
res = res.permute(1, 2, 0)  # restore

print(res.shape)
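To confirm what the permute version computes, the slice res[:, :, k] can be compared against a plain 2D product with the k-th 25x36 matrix, mat2[:, :, k]. This is a quick check under the same shapes as above (the slice index k = 5 is arbitrary). Note that this treats the last dim of mat2 as indexing separate matrices, which is a different operation from a weighted sum over the first dim, even though the output shape (256, 36, 36) is the same.

```python
import torch

b = 256
mat1 = torch.randn(b, 25)
mat2 = torch.randn(25, 36, 36)

# Same permute trick: (36, 25, 36), so each of the 36 slices is a 25x36 matrix
res = torch.matmul(mat1, mat2.permute(2, 0, 1)).permute(1, 2, 0)  # (256, 36, 36)

# Check one slice against a plain 2D product with mat2[:, :, k]
k = 5  # arbitrary slice index chosen for illustration
expected = mat1 @ mat2[:, :, k]  # (256, 36)
print(torch.allclose(res[:, :, k], expected, atol=1e-4))
```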

Thank you so much for your quick response, @111414.
Sorry, I did not describe it clearly.
What I actually want is to use dim 1 of mat1 as weights for a weighted sum of mat2 along its dim 0, which gives a tensor of shape (batch, 36, 36).
In general:
mat1: batch x n1
mat2: n1 x n2 x n3
then mat1 * mat2 → batch x n2 x n3.
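Written out element by element, the requested operation is res[b, i, j] = sum over k of mat1[b, k] * mat2[k, i, j]. Because only the first dim of mat2 is summed over, one way to do this with an ordinary 2D matmul is to flatten the two trailing dims of mat2 and reshape the result back. The sketch below assumes the shapes from the question; the spot-check indices are arbitrary.

```python
import torch

batch, n1, n2, n3 = 256, 25, 36, 36
mat1 = torch.randn(batch, n1)
mat2 = torch.randn(n1, n2, n3)

# Flatten the trailing dims so mat2 becomes an ordinary (n1, n2*n3) matrix.
# The 2D matmul then computes sum_k mat1[b, k] * mat2[k, i, j] for each flattened (i, j).
flat = mat1 @ mat2.reshape(n1, n2 * n3)  # (batch, n2*n3)
res = flat.reshape(batch, n2, n3)        # (batch, n2, n3)
print(res.shape)  # torch.Size([256, 36, 36])

# Spot-check one entry against the explicit weighted sum over k
b, i, j = 3, 7, 11  # arbitrary indices for illustration
manual = (mat1[b, :] * mat2[:, i, j]).sum()
print(torch.allclose(res[b, i, j], manual, atol=1e-4))
```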


I think torch.einsum (see its page in the PyTorch 1.8.1 documentation) works:

import torch
m1 = torch.randn(256, 25)
m2 = torch.randn(25, 36, 36)
res = torch.einsum('bk, kij->bij', m1, m2)
print(res.shape)
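The einsum string 'bk,kij->bij' contracts the last dim of m1 with the first dim of m2. The same contraction can also be expressed with torch.tensordot, where dims=1 sums over the last dim of the first tensor and the first dim of the second. A small equivalence check, assuming the same shapes:

```python
import torch

m1 = torch.randn(256, 25)
m2 = torch.randn(25, 36, 36)

res_einsum = torch.einsum('bk,kij->bij', m1, m2)
# dims=1 contracts the last dim of m1 with the first dim of m2 (the sum over k)
res_tensordot = torch.tensordot(m1, m2, dims=1)

print(res_tensordot.shape)  # torch.Size([256, 36, 36])
print(torch.allclose(res_einsum, res_tensordot, atol=1e-4))
```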

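Thanks a lot, @111414.
I also found that swapping the positions of mat1 and mat2 works, as long as the summed dimension (25) is the last dim of mat2 and the first dim of mat1. The result then has shape (36, 36, batch_size):

batch_size = 256
mat1 = torch.randn(25, batch_size)
mat2 = torch.randn(36, 36, 25)
torch.matmul(mat2, mat1).shape  # torch.Size([36, 36, 256])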
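If mat1 and mat2 start out in the original layouts (batch, 25) and (25, 36, 36), the swapped-order matmul needs a permute and a transpose first, and a final permute to move the batch dim to the front. A sketch, assuming the original shapes, that checks the result against the einsum version:

```python
import torch

batch_size = 256
mat1 = torch.randn(batch_size, 25)  # original layout (batch, n1)
mat2 = torch.randn(25, 36, 36)      # original layout (n1, n2, n3)

# Rearrange into the swapped layout: mat2 as (36, 36, 25), mat1 as (25, batch).
# matmul treats the first dim of the 3D tensor as a batch dim, so each (36, 25)
# slice is multiplied by the (25, batch) matrix.
out = torch.matmul(mat2.permute(1, 2, 0), mat1.T)  # (36, 36, batch)

# Move the batch dim to the front to match the einsum layout
res = out.permute(2, 0, 1)  # (batch, 36, 36)
ref = torch.einsum('bk,kij->bij', mat1, mat2)
print(res.shape)  # torch.Size([256, 36, 36])
print(torch.allclose(res, ref, atol=1e-4))
```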