I have an input x of dimension 1 x 2 and a batch size of 128, … Hence the input gets passed in batches of shape 128 x 2. I have 2 parameters, L_p (dimension 1 x 1) and R_p (dimension 2 x 2). The operation x @ R_p works, but the operation L_p @ x throws an error saying the matrix dimensions don't match, since x is passed as 128 x 2. But each sample of x is actually 1 x 2, so how do I make this work? Please help!
x @ R_p applies a matrix multiplication with the shapes:
[128, 2] @ [2, 2] = [128, 2]
which is the expected result: each of the 128 rows (samples) is multiplied by R_p independently.
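A minimal sketch to check this, assuming random tensors stand in for the real input and R_p (the names batch_size and per_sample are only illustrative). It confirms the output shape and that the batched product matches multiplying every 1 x 2 sample by R_p on its own; float64 is used only so the exact comparison is robust.

```python
import torch

torch.manual_seed(0)
batch_size = 128
x = torch.randn(batch_size, 2, dtype=torch.float64)  # batched input, shape [128, 2]
R_p = torch.randn(2, 2, dtype=torch.float64)          # parameter, shape [2, 2]

out = x @ R_p
print(out.shape)  # torch.Size([128, 2])

# Each row of the batched result equals the single-sample product (1 x 2) @ (2 x 2)
per_sample = torch.cat([x[i:i + 1] @ R_p for i in range(batch_size)], dim=0)
print(torch.allclose(out, per_sample))  # True
```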
However, L_p @ x tries to execute:
[1, 1] @ [128, 2]
which is invalid for a matmul, since the inner dimensions (1 and 128) do not match.
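A short reproduction of the failure, again with random placeholder tensors. The exact error message varies between PyTorch versions, so the sketch only checks that a RuntimeError is raised.

```python
import torch

x = torch.randn(128, 2)  # batched input, shape [128, 2]
L_p = torch.randn(1, 1)  # parameter, shape [1, 1]

# matmul needs the inner dimensions to agree: [1, 1] @ [128, 2] pairs 1 with 128
try:
    L_p @ x
    failed = False
except RuntimeError as err:
    failed = True
    print(err)  # shape-mismatch message; wording depends on the PyTorch version
```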
I'm not sure what the expected output shape is, but assuming that you would like to broadcast L_p in the batch dimension, you could use:
out = L_p.expand(x.size(0), 1, -1) @ x.unsqueeze(1)
which would create an output of shape [128, 1, 2].
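The following is a sketch of that broadcast approach with the intermediate shapes spelled out, assuming random placeholder tensors (x_b, L_b and per_sample are illustrative names). It checks that the result equals computing L_p @ x_i for each 1 x 2 sample separately. It also shows that torch.matmul broadcasts batch dimensions on its own, so L_p @ x.unsqueeze(1) gives the same result without the explicit expand.

```python
import torch

torch.manual_seed(0)
batch_size = 128
x = torch.randn(batch_size, 2, dtype=torch.float64)  # [128, 2]
L_p = torch.randn(1, 1, dtype=torch.float64)          # [1, 1]

# Treat each sample as its own 1 x 2 matrix: [128, 2] -> [128, 1, 2]
x_b = x.unsqueeze(1)

# Repeat L_p along a new batch dimension: [1, 1] -> [128, 1, 1] (a view, no copy)
L_b = L_p.expand(x.size(0), 1, -1)

out = L_b @ x_b  # batched matmul: [128, 1, 1] @ [128, 1, 2] -> [128, 1, 2]
print(out.shape)  # torch.Size([128, 1, 2])

# Same as computing L_p @ x_i for every 1 x 2 sample separately
per_sample = torch.stack([L_p @ x[i:i + 1] for i in range(batch_size)])
print(torch.allclose(out, per_sample))  # True

# torch.matmul also broadcasts the [1, 1] matrix over the batch dimension itself
print(torch.allclose(L_p @ x_b, out))  # True
```

If a [128, 2] result is preferred, out.squeeze(1) drops the singleton dimension again.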