Column-wise matrix multiplication

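This sounds like a normal matrix multiplication. Multiplying weight by input applies the weight to every column of the input at once, and transposing the result gives one output row per input column: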

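import torch
import torch.nn as nn

M, N, L = 3, 4, 5
input = torch.rand(M, N)  # N columns, each a vector of length M (float32 by default)
weight = nn.Parameter(torch.rand(L, M))  # requires_grad=True by default; maps length M to length L
# (L, M) @ (M, N) -> (L, N); transpose so row i is weight applied to column i
result = torch.mm(weight, input).t()
print(result.shape)  # torch.Size([4, 5])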
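To check that this really is "column-wise", here is a small sketch comparing the single torch.mm call against an explicit loop over the input's columns. The names x and looped are my own choices for illustration; x is used instead of input only to avoid shadowing Python's built-in input().

```python
import torch
import torch.nn as nn

M, N, L = 3, 4, 5
x = torch.rand(M, N)                      # N column vectors, each of length M
weight = nn.Parameter(torch.rand(L, M))   # maps a length-M vector to length L

# Batched: one matmul over all columns, then transpose to (N, L)
result = torch.mm(weight, x).t()

# Explicit loop: apply weight to each column separately, stack as rows
looped = torch.stack([weight @ x[:, i] for i in range(N)])

print(result.shape)                    # torch.Size([4, 5])
print(torch.allclose(result, looped))  # True
```

The loop is only there for comparison; the single torch.mm call does the same work in one batched operation, which is why no special "column-wise" function is needed.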