When I read the documentation, nn.Linear() is described as computing y = xA^T + b. However, I need y = Wx + b. Is there any way to get y = Wx + b with nn.Linear()?
nn.Linear does exactly what the documentation says: y = xA^T + b, where A is the weight matrix and A^T denotes A-transpose. This is due to how the data is passed through the network: nn.Linear accepts inputs of size (B, D_in) and transforms them into outputs of size (B, D_out) by a linear transformation using a weight matrix A of size (D_out, D_in).
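A minimal sketch of the shapes involved, assuming B = 128, D_in = 20, D_out = 30 as concrete values; `y_manual` below is a name chosen here for illustration:

```python
import torch
import torch.nn as nn

m = nn.Linear(20, 30)
x = torch.randn(128, 20)            # (B, D_in)
y = m(x)                            # (B, D_out)

print(m.weight.shape)               # torch.Size([30, 20]) -> (D_out, D_in)
print(y.shape)                      # torch.Size([128, 30])

# The forward pass computes y = x A^T + b with A = m.weight:
y_manual = x @ m.weight.T + m.bias
print(torch.allclose(y, y_manual))  # True
```

So the "A" in the formula is stored as (D_out, D_in), and the transpose in xA^T is what makes the shapes line up for row-vector inputs.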
Thanks for the answer. Actually, I have already solved this problem. Previously I had a misunderstanding of this implementation: I treated my input vector as a column vector, so I needed to left-multiply by the weight matrix. However, I can use a row vector instead and right-multiply by the matrix, which still gives me a new row vector.
import torch
import torch.nn as nn

m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
- W: (30, 20), x should be transposed to (20, 128), so Wx = (30, 20) × (20, 128) = (30, 128)
- xW^T = (128, 20) × (20, 30) = (128, 30)
So Wx and xW^T contain the same numbers, just laid out differently: Wx holds each sample as a column, while xW^T holds each sample as a row. Is that right? Thanks.
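A quick check of that column-vs-row claim, assuming the same (20, 30) layer and (128, 20) input as above; `left` and `right` are names chosen here for illustration:

```python
import torch
import torch.nn as nn

m = nn.Linear(20, 30)
x = torch.randn(128, 20)

W = m.weight                          # (30, 20)
left = W @ x.T                        # Wx: samples as columns, (30, 128)
right = x @ W.T                       # xW^T: samples as rows, (128, 30)

# The two layouts hold the same numbers, transposed:
print(torch.allclose(left, right.T))  # True
```

Note that m(x) additionally adds the bias; the check above compares only the matrix products.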