How to perform batch matrix multiplication

I have a batch of matrices A (A.shape = torch.Size([2, 3, 4])) and a matrix B (B.shape = torch.Size([4, 3])). I think of A as consisting of two parts, A1 and A2 (A1.shape = torch.Size([3, 4]), A2.shape = torch.Size([3, 4])).
How can I compute A1 × B and A2 × B separately? My expected result has shape torch.Size([2, 3, 3]).
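The "multiply each part separately" idea from the question can be written literally as a loop over the batch. This is a minimal sketch, assuming A1 = A[0] and A2 = A[1]. It stacks the per-slice products and compares them with torch.matmul on the whole batch:

```python
import torch

# Shapes taken from the question: A holds two 3x4 matrices, B is a single 4x3 matrix.
A = torch.randn(2, 3, 4)
B = torch.randn(4, 3)

# Multiply each slice (A1 = A[0], A2 = A[1]) by B on its own, then stack the results.
per_slice = torch.stack([A[i] @ B for i in range(A.shape[0])])

# torch.matmul broadcasts B across the batch dimension in a single call.
batched = torch.matmul(A, B)

print(per_slice.shape)  # torch.Size([2, 3, 3])
print(torch.allclose(per_slice, batched, atol=1e-6))
```

The loop is only there to show the equivalence. In practice the single matmul call avoids Python-level iteration.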

A batched matrix times a broadcast matrix can be computed with torch.matmul.
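The broadcasting that torch.matmul performs here can be made explicit. As a sketch, B can be expanded (without copying data) to one copy per batch entry and then multiplied with torch.bmm, which requires two 3-D inputs with matching batch sizes:

```python
import torch

A = torch.randn(2, 3, 4)  # batch of 2 matrices, each 3x4
B = torch.randn(4, 3)     # a single 4x3 matrix

# torch.matmul treats the 3-D A as a batch and broadcasts the 2-D B against it.
out = torch.matmul(A, B)

# Explicit equivalent: view B as shape (2, 4, 3) via expand (no copy), then batch-multiply.
B_batched = B.unsqueeze(0).expand(A.shape[0], -1, -1)
out_bmm = torch.bmm(A, B_batched)

print(out.shape)  # torch.Size([2, 3, 3])
```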

I'm not sure that A1/A2 interpretation is correct, lol.
But for what you want to do: reshape A to [(2x3), 4], matmul it with B [4, 3] to get [(2x3), 3], then reshape again to get [2, 3, 3].
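The reshape route described above can be sketched as follows. It folds the batch into the row dimension, does one ordinary 2-D product, and unfolds the result. The variable names are illustrative only:

```python
import torch

A = torch.randn(2, 3, 4)
B = torch.randn(4, 3)

batch, rows, inner = A.shape
# Fold the batch dimension into the rows: (2, 3, 4) -> (6, 4).
A_flat = A.reshape(batch * rows, inner)
# Ordinary 2-D matrix product: (6, 4) @ (4, 3) -> (6, 3).
out_flat = A_flat @ B
# Unfold back to one 3x3 result per batch entry: (6, 3) -> (2, 3, 3).
out = out_flat.reshape(batch, rows, B.shape[1])

print(out.shape)  # torch.Size([2, 3, 3])
```

This works because every row of every matrix in A is multiplied by the same B. It gives the same values as torch.matmul(A, B).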

Can you tell me the details? Is my code correct?


Yes, you're right. Here are the details:

>>> import torch
>>> a = torch.randn(2,3,4)
>>> b = torch.randn(4,3)
>>> torch.matmul(a,b).shape
torch.Size([2, 3, 3])