Hi Neel!
The short story: I speculate that pytorch follows a Linear.weight
“transpose” convention for improved performance on “low-end” hardware.
I’m not sure what you mean by “right facing,” but I comment below on what
I believe are the likely roots of pytorch’s choice of Linear
implementation.
As an aside, I believe this should be residual_stream @ W_in.T. Based on
your shapes (and your einsum() expression), your version (without the
transpose) will fail because the second dimension of residual_stream
doesn’t match the first dimension of W_in.
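For concreteness, here is a minimal sketch of that shape mismatch, using
hypothetical shapes for residual_stream and W_in (substitute your actual
dimensions):

import torch

# hypothetical shapes -- substitute your actual values
d_model, d_mlp, seq_len = 512, 2048, 16
residual_stream = torch.randn (seq_len, d_model)
W_in = torch.randn (d_mlp, d_model)   # stored Linear-style: [out_features, in_features]

# residual_stream @ W_in would raise a shape-mismatch error:
# [16, 512] @ [2048, 512] has mismatched inner dimensions
hidden = residual_stream @ W_in.T     # [16, 512] @ [512, 2048]
print (hidden.shape)                  # torch.Size([16, 2048])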
As for the weight of Linear, I believe that pytorch chooses to have the
weight multiply the input to Linear from the right so that Linear can
accept inputs with any number of leading “batch” dimensions, e.g., no
batch dimension, a single batch dimension, or a batch and a channel
dimension.
Thus shapes [10], [5, 10], and [5, 3, 10] would all be valid inputs to
Linear (in_features = 10, out_features = 2).
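A quick illustration of the leading-batch-dimension point (just a sketch):

import torch

lin = torch.nn.Linear (in_features = 10, out_features = 2)

print (lin (torch.randn (10)).shape)        # torch.Size([2])        -- no batch dimension
print (lin (torch.randn (5, 10)).shape)     # torch.Size([5, 2])     -- one batch dimension
print (lin (torch.randn (5, 3, 10)).shape)  # torch.Size([5, 3, 2])  -- batch and channel dimensions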
However, torch.nn.Linear (in_features = 10, out_features = 2) has a weight
of shape [2, 10], and you can see by inspection that Linear multiplies its
input from the right by the transpose of its weight, that is, e.g.,
torch.randn (5, 10) @ torch.nn.Linear (10, 2).weight.T.
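To spell this out, Linear’s forward pass is equivalent to right-multiplying
the input by weight.T (and adding bias); here is a short check:

import torch

lin = torch.nn.Linear (10, 2)
x = torch.randn (5, 10)

print (lin.weight.shape)                                        # torch.Size([2, 10])
print (torch.allclose (lin (x), x @ lin.weight.T + lin.bias))   # True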
Why is the matrix that is stored as weight the transpose of the matrix
that is used in the actual matrix-matrix multiplication (or, more
generally, the tensor-matrix multiplication)?
Although there have been historical exceptions, row-major is the de facto
standard for matrix storage*. That is to say that as you move sequentially
through the (linear) address space in which a matrix is stored, the column
index varies the most rapidly. Equivalently, individual rows are stored
sequentially (and then chained together one after the other).
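In pytorch terms, a freshly created tensor is stored row-major, which you
can see from its strides:

import torch

t = torch.arange (6).reshape (2, 3)   # [[0, 1, 2], [3, 4, 5]]
print (t.stride ())                   # (3, 1) -- stepping along a row moves through adjacent memory
print (t.flatten ())                  # tensor([0, 1, 2, 3, 4, 5]) -- rows chained one after the other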
It is therefore true that matrix-matrix-transpose multiplication has
greater memory locality than matrix-matrix multiplication: as you perform
the row-column vector-vector dot products of which matrix multiplication
is composed, moving sequentially through a column of the transposed matrix
translates to moving sequentially through the corresponding row of the
untransposed matrix, so the data in the transposed matrix’s column is
accessed sequentially in memory.
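Concretely, a column of the transposed matrix is just a (contiguous) row
of the untransposed matrix, whereas a column of a plain row-major matrix
is strided:

import torch

mat = torch.randn (4, 4)
print (mat[:, 0].is_contiguous ())    # False -- a column of a row-major matrix is strided in memory
print (mat.T[:, 0].is_contiguous ())  # True  -- a column of the transpose is a row of mat, hence contiguous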
Whether this memory locality affects performance is highly dependent on
the details of the hardware – size and page size of cache memory, relative
cost of cache hits vs. misses, structure and number of floating-point pipelines,
cpu vs. gpu, and so on. On “small” systems where a row of a “large” matrix
might fit into cache memory, but the whole matrix doesn’t, matrix-matrix
multiplication could be dramatically slower than matrix-matrix-transpose
multiplication, and significant algorithmic effort was made in deciding whether
to store a matrix or its transpose (or even sometimes both).
With modern gpus, this is probably much less likely to matter. I’ll also note
that gpus (as well as most modern cpu pipelines) are optimized to be able
to stride efficiently through memory (that is, with strides greater than one,
rather than purely sequentially), which is what you are doing when you access
the column of a matrix (that is stored in row-major form).
Having said all of that, here are some real-life timings for matrix-matrix vs.
matrix-matrix-transpose multiplication on one particular system.
Here is the test script:
import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())

import math
import statistics
import time

_ = torch.manual_seed (2022)

device = 'cpu'
print ('device =', device)

nWarm = 100   # number of warm-up iterations
nStat = 10    # number of timing samples
nTime = 100   # number of multiplications per timing sample
print ('nStat =', nStat, ', nTime =', nTime)

b = 100
m = 10000
n = 10000
print ('b =', b, ', m =', m, ', n =', n)

a = torch.randn (b, m, device = device)    # "input" tensor
mr = torch.randn (m, n, device = device)   # matrix used as is
ml = torch.randn (n, m, device = device)   # matrix used transposed

# warm up
for i in range (nWarm):
    p = a @ mr
    p = a @ ml.T

mr_timing = []
mlT_timing = []
for i in range (nStat):
    # time matrix-matrix multiplication
    torch.cuda.synchronize()
    tStart = time.time()
    for i in range (nTime):
        p = a @ mr
    torch.cuda.synchronize()
    tEnd = time.time()
    mr_timing.append (tEnd - tStart)
    # time matrix-matrix-transpose multiplication
    torch.cuda.synchronize()
    tStart = time.time()
    for i in range (nTime):
        p = a @ ml.T
    torch.cuda.synchronize()
    tEnd = time.time()
    mlT_timing.append (tEnd - tStart)

mr_mean = statistics.mean (mr_timing)
mr_stderr = statistics.stdev (mr_timing) / math.sqrt (nStat - 1)
mlT_mean = statistics.mean (mlT_timing)
mlT_stderr = statistics.stdev (mlT_timing) / math.sqrt (nStat - 1)
print ('a @ mr timings:', mr_mean, '+-', mr_stderr)
print ('a @ ml.T timing:', mlT_mean, '+-', mlT_stderr)
Here is its cpu-version output:
1.12.0
11.6
GeForce GTX 1050 Ti
device = cpu
nStat = 10 , nTime = 100
b = 100 , m = 10000 , n = 10000
a @ mr timings: 12.70658016204834 +- 0.007860937265143848
a @ ml.T timing: 12.256337642669678 +- 0.007302343186920837
And here is its companion gpu-version output:
1.12.0
11.6
GeForce GTX 1050 Ti
device = cuda
nStat = 10 , nTime = 100
b = 100 , m = 10000 , n = 10000
a @ mr timings: 1.354783058166504 +- 0.0005118169535293032
a @ ml.T timing: 1.3442912817001342 +- 0.0014081416198141895
For this particular system, the performance differences are quite small, and
not necessarily statistically significant. (If significant, they could perhaps
be meaningful at the margins – maybe a percent or so, but certainly not
something like a factor of two).
It’s interesting to note – but I wouldn’t read too much into it – that
matrix-matrix-transpose multiplication outperforms matrix-matrix multiplication
on both the cpu and gpu by a little bit. (Of course, not surprisingly, the gpu
outperforms the cpu by a factor of ten – which one would hope because this
is precisely the kind of thing that should be fast on a gpu.)
Again, any such differences will depend on the hardware, and will likely also
depend on the size and shape of the tensors involved.
So, why does pytorch choose to store its multiplying-from-the-right weight
matrix as its transpose?
It could just be historical habit from the time when the matrix-transpose
optimization almost always mattered for large matrices.
But it could be that on low-end hardware (think inference on an embedded
system) the matrix-transpose optimization really does pay off, while on
“fancy” hardware it doesn’t really matter, one way or the other.
*) Note, it is possible to have a Linear store its weight in column-major
form by setting its weight to the transpose view of a row-major matrix:
lin = torch.nn.Linear (10, 2, bias = False)
lin.weight.shape                       # torch.Size([2, 10]) -- row-major (contiguous)
wt_column_major = torch.randn (10, 2)
wt_column_major.shape                  # torch.Size([10, 2])
lin.weight = torch.nn.Parameter (wt_column_major.T)
lin.weight.shape                       # still torch.Size([2, 10]), but now column-major (non-contiguous)
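As a quick sanity check (just a sketch), you can confirm that the
swapped-in weight is no longer contiguous (i.e., is stored column-major)
and that the layer still works:

lin.weight.is_contiguous ()            # False -- the underlying storage is now column-major
lin (torch.randn (5, 10)).shape        # torch.Size([5, 2]) -- the forward pass still works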
Best.
K. Frank