I’m trying to multiply a single matrix A of shape (80, 256) with a batch of other matrices B of shape (16, 256, 65).
In numpy one can call np.dot(A, B) to perform the operation.
How can one achieve this in pytorch?
torch.mm only handles 2D matrices, and torch.bmm requires both inputs to be batched with the same batch size.
torch.bmm(A.expand(16, 80, 256), B)
Yes, torch.bmm doesn’t support broadcasting right now but the above should work.
edit: you should also be able to use torch.matmul(A, B)
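Both suggestions can be verified with a quick sketch: expanding A to the batch size for torch.bmm, and letting torch.matmul broadcast the 2D matrix against the batched one automatically. The shapes below are the ones from the question.

```python
import torch

A = torch.randn(80, 256)       # single matrix
B = torch.randn(16, 256, 65)   # batch of 16 matrices

# Option 1: expand A to (16, 80, 256) and use batched matmul.
# expand() creates a broadcasted view, so no data is copied.
out_bmm = torch.bmm(A.expand(16, 80, 256), B)

# Option 2: matmul broadcasts the 2D matrix over the batch dimension.
out_matmul = torch.matmul(A, B)

print(out_bmm.shape)     # torch.Size([16, 80, 65])
print(out_matmul.shape)  # torch.Size([16, 80, 65])

# Both compute the same result.
print(torch.allclose(out_bmm, out_matmul, atol=1e-5))
```

Note that torch.matmul is the simpler option here, since it applies broadcasting without an explicit expand.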