Performance of Right-Facing vs Left-Facing matmuls

I’m writing an implementation of a transformer to pre-train from scratch, and wrote my matrices to be left-multiplying (e.g., in an MLP layer with 4000 neurons and a 1000-dimensional residual stream, W_in.shape==[4000, 1000] and neuron_pre_act = einsum("nm,bm->bn", W_in, residual_stream)). I notice that in most implementations, e.g. nn.Linear, matrices are right-facing, which lets you use notation like neuron_pre_act = residual_stream @ W_in.

I’m curious if anyone knows why this convention exists, and how much it matters for, e.g., performance. Based on some preliminary tests, right-facing is a performance improvement, but the gain seems to vary a ton in ways I don’t fully understand: in some contexts it’s a minor 0.5% speed-up, in others it’s a major 5-10%?!
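For concreteness, here is a small sketch of the two conventions I mean (the batch size and shapes are just placeholders):

import torch

batch, d_model, d_mlp = 32, 1000, 4000
residual_stream = torch.randn(batch, d_model)

# Left-facing: W_in has shape [d_mlp, d_model]
W_in_left = torch.randn(d_mlp, d_model)
pre_act_left = torch.einsum("nm,bm->bn", W_in_left, residual_stream)

# Right-facing: W_in has shape [d_model, d_mlp], so a plain @ works
W_in_right = W_in_left.T
pre_act_right = residual_stream @ W_in_right

print(torch.allclose(pre_act_left, pre_act_right, atol=1e-4))  # True (loose atol for float32)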

Hi @Neel_Nanda,

Can you share a minimal reproducible example?

  1. When you compare torch.einsum with nn.Linear, make sure you use nn.Linear(bias=False); otherwise the operations aren’t equivalent (see the quick check sketched after this list).

  2. When you are measuring times for code snippets, make sure that you synchronize torch (torch.cuda.synchronize) before calling time.time(); otherwise you won’t record the whole runtime of an operation but only its (asynchronous) launch (especially if run on the GPU). There’s more info about it here.
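For point 1, a quick sketch of that equivalence check (shapes are placeholders, and atol is loosened for float32):

import torch

batch, d_in, d_out = 8, 1000, 4000
x = torch.randn(batch, d_in)

# nn.Linear(bias=False) computes x @ weight.T, with weight of shape [d_out, d_in],
# which matches an einsum over the weight's second index.
lin = torch.nn.Linear(d_in, d_out, bias=False)
out_linear = lin(x)
out_einsum = torch.einsum("nm,bm->bn", lin.weight, x)

print(torch.allclose(out_linear, out_einsum, atol=1e-4))  # True up to float error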

Minimal counter-example

import time
import torch

batch = 3 * 10**4
residual = 3 * 10**4
d_mlp = 3 * 10**4

input_mat = torch.randn(batch, residual, device='cuda')
right_mat = torch.randn(residual, d_mlp, device='cuda')
left_mat = torch.randn(d_mlp, residual, device='cuda')

print("Cuda warmup")
_ = (input_mat @ right_mat)
torch.cuda.synchronize()

start_time = time.time()
_ = (input_mat @ right_mat)
torch.cuda.synchronize()
print("Time taken right mul: ", time.time() - start_time)

start_time = time.time()
_ = (input_mat @ left_mat.T)
torch.cuda.synchronize()
print("Time taken left mul transpose: ", time.time() - start_time)

start_time = time.time()
_ = torch.einsum("rm,br->bm", right_mat, input_mat)
torch.cuda.synchronize()
print("Time taken right mul einsum: ", time.time() - start_time)

start_time = time.time()
_ = torch.einsum("mr,br->bm", left_mat, input_mat)
torch.cuda.synchronize()
print("Time taken left mul einsum: ", time.time() - start_time)

On my A100 80GB GPU this returns:

Cuda warmup
Time taken right mul:  3.2466354370117188
Time taken left mul transpose:  3.315218448638916
Time taken right mul einsum:  3.061365842819214
Time taken left mul einsum:  3.315138339996338

Some more experiments:

In float16:

Cuda warmup
Time taken right mul:  0.23980021476745605
Time taken left mul transpose:  0.2591671943664551
Time taken right mul einsum:  0.2865767478942871
Time taken left mul einsum:  0.27009034156799316

In bfloat16

Cuda warmup
Time taken right mul:  0.23795533180236816
Time taken left mul transpose:  0.25300073623657227
Time taken right mul einsum:  0.2702064514160156
Time taken left mul einsum:  0.2489306926727295

With 10**4 rather than 3*10**4

Cuda warmup
Time taken right mul:  0.11016678810119629
Time taken left mul transpose:  0.11061906814575195
Time taken right mul einsum:  0.11274075508117676
Time taken left mul einsum:  0.11332011222839355

Hi Neel!

The short story: I speculate that pytorch follows a Linear.weight “transpose”
convention for improved performance on “low-end” hardware.

I’m not sure what you mean by “right facing,” but I comment below on what
I believe are the likely roots of pytorch’s choice of Linear implementation.

As an aside, I believe this should be residual_stream @ W_in.T. Based
on your shapes (and your einsum() expression), your version (without the
transpose) will fail because the second dimension of residual_stream
doesn’t match the first dimension of W_in.

As for the weight of Linear, I believe that pytorch chooses to have the
weight multiply the input to Linear from the right so that Linear can
accept inputs with any number of leading “batch” dimensions, e.g., no
batch dimension, a single batch dimension, or a batch and a channel
dimension.

Thus shapes [10], [5, 10], and [5, 3, 10] would all be valid inputs to
Linear (in_features = 10, out_features = 2).
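Concretely, a quick sketch using those same shapes:

import torch

lin = torch.nn.Linear (10, 2)

# Only the last dimension has to equal in_features; any leading
# "batch" dimensions are carried through unchanged.
print (lin (torch.randn (10)).shape)        # torch.Size([2])
print (lin (torch.randn (5, 10)).shape)     # torch.Size([5, 2])
print (lin (torch.randn (5, 3, 10)).shape)  # torch.Size([5, 3, 2])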

However, torch.nn.Linear (in_features = 10, out_features = 2)
has a weight of shape [2, 10] and you can see by inspection that Linear
multiplies its input from the right by the transpose of its weight, that is, e.g.,
torch.randn (5, 10) @ torch.nn.Linear (10, 2).weight.T.
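A small sanity check of that equivalence (with bias = False so the two expressions agree up to round-off):

import torch

lin = torch.nn.Linear (10, 2, bias = False)
x = torch.randn (5, 10)

# Linear applies the transpose of its stored weight from the right.
print (torch.allclose (lin (x), x @ lin.weight.T))   # True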

Why is the matrix that is stored as weight the transpose of the matrix
that is used in the actual matrix-matrix multiplication (or, more generally,
the tensor-matrix multiplication)?

Although there have been historical exceptions, row-major is the de facto
standard for matrix storage*. That is to say that as you move sequentially
though the (linear) address space in which a matrix is stored, the column
index varies the most rapidly. Equivalently, individual rows are stored
sequentially (and then chained together one after the other).

Matrix-matrix-transpose multiplication therefore has greater memory locality
than plain matrix-matrix multiplication: as you perform the row-column
vector-vector dot products of which matrix multiplication is composed,
moving sequentially through a column of the transposed matrix translates to
moving sequentially through the corresponding row of the untransposed matrix,
so the data in the transposed matrix’s column is accessed sequentially in memory.
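You can see this layout reflected in the strides pytorch reports; a small sketch:

import torch

m = torch.randn (4, 6)       # stored row-major (contiguous)
print (m.stride())           # (6, 1): stepping along a row moves 1 element in memory
print (m.T.stride())         # (1, 6): the transpose is just a view with swapped strides
print (m.T.is_contiguous())  # False

# So a column of m.T is a row of m, i.e., a sequential run of memory.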

Whether this memory locality affects performance is highly dependent on
the details of the hardware – size and page size of cache memory, relative
cost of cache hits vs. misses, structure and number of floating-point pipelines,
cpu vs. gpu, and so on. On “small” systems where a row of a “large” matrix
might fit into cache memory, but the whole matrix doesn’t, matrix-matrix
multiplication could be dramatically slower than matrix-matrix-transpose
multiplication, and significant algorithmic effort was made in deciding whether
to store a matrix or its transpose (or even sometimes both).

With modern gpus, this is probably much less likely to matter. I’ll also note
that gpus (as well as most modern cpu pipelines) are optimized to be able
to stride efficiently through memory (that is, with strides greater than one,
rather than purely sequentially), which is what you are doing when you access
the column of a matrix (that is stored in row-major form).

Having said all of that, here are some real-life timings for matrix-matrix vs.
matrix-matrix-transpose multiplication on one particular system.

Here is the test script:

import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())

import math
import statistics
import time

_ = torch.manual_seed (2022)

device = 'cpu'   # change to 'cuda' for the gpu-version run below
print ('device =', device)

nWarm = 100
nStat = 10
nTime = 100
print ('nStat =', nStat, ', nTime =', nTime)

b = 100
m = 10000
n = 10000
print ('b =', b, ', m =', m, ', n =', n)

a = torch.randn (b, m, device = device)
mr = torch.randn (m, n, device = device)
ml = torch.randn (n, m, device = device)

# warm up
for i in range (nWarm):
    p = a @ mr
    p = a @ ml.T

mr_timing = []
mlT_timing = []

for  i in range (nStat):
    torch.cuda.synchronize()
    tStart = time.time()
    for  i in range (nTime):
        p = a @ mr
    torch.cuda.synchronize()
    tEnd = time.time()
    mr_timing.append (tEnd - tStart)
    
    torch.cuda.synchronize()
    tStart = time.time()
    for  i in range (nTime):
        p = a @ ml.T
    torch.cuda.synchronize()
    tEnd = time.time()
    mlT_timing.append (tEnd - tStart)

mr_mean = statistics.mean (mr_timing)
mr_stderr = statistics.stdev (mr_timing) / math.sqrt (nStat - 1)
mlT_mean = statistics.mean (mlT_timing)
mlT_stderr = statistics.stdev (mlT_timing) / math.sqrt (nStat - 1)

print ('a @ mr  timings:', mr_mean, '+-', mr_stderr)
print ('a @ ml.T timing:', mlT_mean, '+-', mlT_stderr)

Here is its cpu-version output:

1.12.0
11.6
GeForce GTX 1050 Ti
device = cpu
nStat = 10 , nTime = 100
b = 100 , m = 10000 , n = 10000
a @ mr  timings: 12.70658016204834 +- 0.007860937265143848
a @ ml.T timing: 12.256337642669678 +- 0.007302343186920837

And here is its companion gpu-version output:

1.12.0
11.6
GeForce GTX 1050 Ti
device = cuda
nStat = 10 , nTime = 100
b = 100 , m = 10000 , n = 10000
a @ mr  timings: 1.354783058166504 +- 0.0005118169535293032
a @ ml.T timing: 1.3442912817001342 +- 0.0014081416198141895

For this particular system, the performance differences are quite small, and
not necessarily statistically significant. (If significant, they could perhaps
be meaningful at the margins – maybe a percent or so, but certainly not
something like a factor of two).

It’s interesting to note – but I wouldn’t read too much into it – that
matrix-matrix-transpose multiplication outperforms matrix-matrix multiplication
on both the cpu and gpu by a little bit. (Of course, not surprisingly, the gpu
outperforms the cpu by a factor of ten – which one would hope because this
is precisely the kind of thing that should be fast on a gpu.)

Again, any such differences will depend on the hardware, and will likely also
depend on the size and shape of the tensors involved.

So, why does pytorch choose to store its multiplying-from-the-right weight
matrix as its transpose?

It could just be historical habit from the time when the matrix-transpose
optimization almost always mattered for large matrices.

But it could be that on low-end hardware (think inference on an embedded
system) the matrix-transpose optimization really does pay off, while on
“fancy” hardware it doesn’t really matter, one way or the other.

*) Note, it is possible to have a Linear store its weight in column-major
form by setting its weight to the transpose view of a row-major matrix:

import torch

lin = torch.nn.Linear (10, 2, bias = False)
print (lin.weight.shape)             # torch.Size([2, 10]), stored row-major
wt_column_major = torch.randn (10, 2)
print (wt_column_major.shape)        # torch.Size([10, 2])
lin.weight = torch.nn.Parameter (wt_column_major.T)
print (lin.weight.shape)             # still torch.Size([2, 10]) ...
print (lin.weight.is_contiguous())   # ... but False: the data is now column-major

Best.

K. Frank


Thanks so much, that’s a really helpful answer! And interesting that it mostly doesn’t matter on larger/more modern systems.
I also was completely wrong about nn.Linear and hadn’t noticed the transpose, so thanks for correcting me on that.