Isn't this just a normal matrix multiplication?
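import torch
import torch.nn as nn

M, N, L = 3, 4, 5
input = torch.rand(M, N).float()  # shape (M, N) = (3, 4)
weight = nn.Parameter(torch.rand(L, M).float(), requires_grad=True)  # shape (L, M) = (5, 3)
result = torch.mm(weight, input).t()  # (L, M) @ (M, N) -> (L, N), transposed to (N, L)
print(result.shape)  # torch.Size([4, 5])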
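A small sketch to check that claim: the expression is an ordinary matrix product followed by a transpose, so by the identity (W X)^T = X^T W^T it equals `x.t() @ weight.t()`. That is also what `torch.nn.functional.linear` computes for a bias-free layer. Relating it to `F.linear` is an assumption about the intent of the original code, and `x` is simply `input` renamed to avoid shadowing Python's built-in.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed only so the comparison is reproducible
M, N, L = 3, 4, 5
x = torch.rand(M, N)                      # same role as `input` above, shape (M, N)
weight = nn.Parameter(torch.rand(L, M))   # shape (L, M)

# Original expression: (L, M) @ (M, N) -> (L, N), then transpose -> (N, L)
result = torch.mm(weight, x).t()

# Same product via the transpose identity (W X)^T = X^T W^T
via_transpose = torch.mm(x.t(), weight.t())

# F.linear(inp, W) computes inp @ W^T, like nn.Linear(M, L, bias=False)
via_linear = F.linear(x.t(), weight)

print(result.shape)  # torch.Size([4, 5])
print(torch.allclose(result, via_transpose), torch.allclose(result, via_linear))
```

Because `F.linear` expects the feature dimension last, the data has to be laid out as (N, M), i.e. one row per sample, which is why `x.t()` is passed in.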