TL;DR: I don’t know how I should picture the way input is processed, i.e. whether input is a single vector or a sequence of vectors, because the multiplication that happens is different in either case.
import torch
import torch.nn as nn
lin = nn.Linear(6, 10, True)  # in_features=6, out_features=10, bias=True
W = lin.weight  # BUT: W will be a 10 x 6 matrix!
bias = lin.bias
I want to manually compute what nn.Linear does, to understand what is happening.
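For reference, here is what the parameter shapes look like (just a quick sketch to pin down what lin actually stores):
print(lin.weight.shape)  # torch.Size([10, 6]), i.e. (out_features, in_features)
print(lin.bias.shape)    # torch.Size([10])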
In PyTorch, vectors are ROW vectors. In math, people write them as COLUMN vectors. So funny stuff happens when you try to code up a mathematical formulation.
x = torch.randn(6) # 6 dimensional ROW vector.
out1 = lin(x) # 10 dimensional ROW vector.
If we want to compute this manually, we do:
out2 = torch.matmul(W, x) + bias
But hold on! On paper this computation shouldn’t make sense, because W is a 10 x 6 matrix and x is a 1 x 6 matrix. But because x is just a 1-D tensor with 6 entries, PyTorch is able to treat it like a 6 x 1 matrix (COLUMN vector). That’s what the docs say. It doesn’t work if you use torch.mm, and I will not use @ because, as far as I can tell, that transposes under the hood for you.
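For what it’s worth, here is a small sanity check, continuing from the code above (just a sketch; I’m assuming torch.allclose is a reasonable way to compare the float outputs):
print(out1.shape, out2.shape)  # both torch.Size([10])
print(torch.allclose(out1, out2))  # expected: True
out3 = torch.matmul(x, W.T) + bias  # the "row vector times W transposed" ordering also works for a 1-D x
print(torch.allclose(out1, out3))  # expected: True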
So let’s try again, but with a different x.
x = torch.rand(2, 6) # Two 6-dimensional row vectors.
out1 = lin(x) # Works
out2 = torch.matmul(x, W.T) + bias  # Now all of a sudden we have to care about the dimensions again
So I struggle to understand how I should picture the input being processed. When I model something, can I think of the input as a single vector? But when it trains, you can pass in a tensor of vectors, and then either it computes the result for each individual vector sequentially, like before, or it does some fancy transposing to pull that off. I really hate that I have no idea which one it does.
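To make that concrete, here is the check I would do to convince myself it is row-by-row (a sketch, continuing from the two-row x above):
rowwise = torch.stack([lin(x[i]) for i in range(x.shape[0])])  # apply lin to each 6-dimensional row separately
print(torch.allclose(out1, rowwise))  # expected: True, i.e. batching is just "one linear map per row"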
According to the docs, it computes y = x A^T + b.
However, it’s unclear what A is. Is it `lin.weight` (so 10 x 6), or does it have the dimensions I passed to the constructor (6 x 10)?
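If A really is lin.weight, then the following check should pass (a sketch; the batch xb is just made up for the test):
A = lin.weight  # 10 x 6, i.e. (out_features, in_features)
xb = torch.randn(4, 6)  # hypothetical batch of four 6-dimensional row vectors
print(torch.allclose(lin(xb), torch.matmul(xb, A.T) + lin.bias))  # expected: True if A is lin.weight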
To connect this back to nn.Embedding:
embed = nn.Embedding(8, 6)
E = embed.weight  # THIS IS an 8 x 6 matrix!
So the transposing stuff doesn’t happen here. At least I have somewhat more control, because the only thing nn.Embedding does is generate a random matrix and let me access its rows by passing in a tensor of indices. Mathematically, that is equivalent to multiplying that matrix with one-hot vectors (all entries zero except for a single 1), but if a computer can access the rows by index anyway, it seems fine to omit that multiplication. Still, it’s a bit annoying that one layer transposes its weight and the other does not.
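To double-check the one-hot claim, here is a small sketch (the indices are made up):
idx = torch.tensor([3, 0, 5])  # hypothetical indices
one_hot = torch.nn.functional.one_hot(idx, num_classes=8).float()  # shape (3, 8)
print(torch.allclose(embed(idx), torch.matmul(one_hot, E)))  # expected: True, each one-hot row picks out a row of E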