I am trying to use torch.mm to do the following matrix operation.
If matrix is an M x N tensor and batch is an N x B tensor, how can I achieve this:
for each batch index i, compute matrix @ batch_i (where batch_i is column i of batch), which gives a vector of length M, and then put the B results together so that the output tensor has shape M x B?
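To pin down the intended result, here is a slow reference loop. It is only a sketch: the small sizes are illustrative, and `matrix`/`batch` simply mirror the names used in the question.

```python
import torch

M, N, B = 3, 4, 5
matrix = torch.rand(M, N)
batch = torch.rand(N, B)

# Reference definition: for each batch index i, multiply matrix by column i
# of batch (a length-N vector), which yields a length-M vector.
columns = [matrix @ batch[:, i] for i in range(B)]

# Stack the B result vectors as columns -> shape (M, B).
out = torch.stack(columns, dim=1)
```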
There are two problems here:
1. torch.bmm seems to require both inputs to be batched, but my first input is not.
2. torch.bmm expects the batch size to be the first dimension, while in my tensor it is the last.
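One possible way around both points, sketched here with illustrative sizes, is torch.matmul instead of torch.bmm: matmul broadcasts a 2-D matrix against a 3-D batch, so the first input does not need to be batched, and the batch dimension can be moved to the front with `.T` and back again afterwards.

```python
import torch

M, N, B = 3, 4, 5
matrix = torch.rand(M, N)
batch = torch.rand(N, B)

# Move the batch dimension to the front and make each column an (N, 1)
# matrix: (N, B) -> (B, N) -> (B, N, 1).
batch_first = batch.T.unsqueeze(-1)

# torch.matmul broadcasts the 2-D matrix over the batch dimension,
# so matrix does not have to be repeated B times.
result = torch.matmul(matrix, batch_first)  # shape (B, M, 1)

# Drop the trailing 1 and put the batch dimension last: (B, M) -> (M, B).
out = result.squeeze(-1).T
```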
As the PyTorch documentation states, torch.bmm is meant for the case where both inputs are batches of matrices, with shapes (b, n, m) and (b, m, p).
Here, though, you are multiplying a batch of vectors by a single constant matrix.
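Because the matrix is the same for every batch element, no batched routine is strictly needed: column i of x @ y is exactly x @ y[:, i], so one ordinary matrix product already produces the M x B result. A minimal sketch, with illustrative sizes:

```python
import torch

M, N, B = 10, 12, 14
x = torch.rand(M, N)
y = torch.rand(N, B)

# Column i of x @ y equals x @ y[:, i], so a single 2-D matrix product
# computes every batch element and lays the results out as (M, B).
out = torch.mm(x, y)

# Spot-check one batch element against the explicit matrix-vector product.
i = 3
check = torch.allclose(out[:, i], torch.mv(x, y[:, i]))
```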
If you really want to use bmm, you can put both inputs into the format it expects, like this:
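```python
import torch

M, N, B = 10, 12, 14
x = torch.rand(M, N)
y = torch.rand(N, B)
x_batch = x.repeat(B, 1, 1)  # make x_batch be the shape of (B, M, N)
y_batch = y.T.unsqueeze(2)  # make y_batch be the shape of (B, N, 1)
results = torch.bmm(x_batch, y_batch)  # shape (B, M, 1)
out = results.squeeze(2).T  # drop the trailing 1 and move the batch dim last: (M, B)
```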
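If you stay with bmm, a variation worth considering (a sketch, not the only option) is `expand` instead of `repeat`: `expand` returns a (B, M, N) view that shares x's storage rather than allocating B copies. The last line converts bmm's (B, M, 1) output into the requested M x B layout.

```python
import torch

M, N, B = 10, 12, 14
x = torch.rand(M, N)
y = torch.rand(N, B)

# expand creates a (B, M, N) view of x without copying its data,
# whereas repeat allocates B full copies.
x_batch = x.expand(B, M, N)
y_batch = y.T.unsqueeze(2)             # (B, N, 1)
results = torch.bmm(x_batch, y_batch)  # (B, M, 1)

# Drop the trailing 1 and put the batch dimension last: (M, B).
out = results.squeeze(2).T
```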