Matrix-matrix multiply with different batch sizes

I’m trying to multiply a single matrix A of shape (80, 256) with a batch of other matrices B of shape (16, 256, 65).
In NumPy, one can call …(A, B) to perform the operation.
How can one achieve this in PyTorch? … or torch.bmm requires the batch sizes to be the same.

torch.bmm(A.expand(16, 80, 256), B)  # result has shape (16, 80, 65)

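Yes, torch.bmm doesn’t support broadcasting right now, but the above should work: A.expand(16, 80, 256) returns a view of A repeated along a new batch dimension of size 16, without copying its data.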
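A minimal sketch of the expand-then-bmm approach, assuming random float32 inputs with the shapes from the question. It checks that expand produces a stride-0 view rather than a copy, and that the batched product has the expected shape.

```python
import torch

# Shapes from the question: one (80, 256) matrix and a batch of 16 (256, 65) matrices.
# Random values are an assumption, used only for illustration.
A = torch.randn(80, 256)
B = torch.randn(16, 256, 65)

# expand adds a leading batch dimension of size 16 as a view:
# the new dimension has stride 0, so A's data is not duplicated.
A_batched = A.expand(16, 80, 256)
print(A_batched.stride())  # (0, 256, 1)

# torch.bmm needs two 3-D tensors with the same batch size.
C = torch.bmm(A_batched, B)
print(C.shape)  # torch.Size([16, 80, 65])
```

Each slice C[i] equals the 2-D product A @ B[i].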

Edit: you should also be able to use torch.matmul(A, B), which broadcasts A across the batch dimension of B.
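A minimal sketch of the torch.matmul version, again assuming random inputs with the question's shapes. It compares the broadcast result against the explicit expand + bmm call and against a plain loop over the batch.

```python
import torch

# Random inputs with the shapes from the question (values are illustrative only).
A = torch.randn(80, 256)
B = torch.randn(16, 256, 65)

# matmul broadcasts the 2-D A across B's batch dimension:
# (80, 256) @ (16, 256, 65) -> (16, 80, 65)
C_matmul = torch.matmul(A, B)  # equivalently: A @ B
print(C_matmul.shape)  # torch.Size([16, 80, 65])

# Reference results: explicit expand + bmm, and one 2-D product per batch entry.
C_bmm = torch.bmm(A.expand(16, 80, 256), B)
C_loop = torch.stack([A @ B[i] for i in range(B.shape[0])])

# Tolerances allow for float32 rounding differences between kernels.
same_as_bmm = torch.allclose(C_matmul, C_bmm, rtol=1e-4, atol=1e-4)
same_as_loop = torch.allclose(C_matmul, C_loop, rtol=1e-4, atol=1e-4)
```

Using matmul avoids hard-coding the batch size 16 in the expand call.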


Thank you @richard, torch.matmul(A, B) does, indeed, do the trick!