Why does the Linear module seem to do unnecessary transposing?

I was looking at the code for torch.nn.Linear(in_features, out_features, bias=True) and it seems that it stores the weight matrix one way but then transposes it at compute time, even though the transpose looks avoidable. Why does it store the matrix with dimensions (out_features, in_features) (code: http://pytorch.org/docs/master/_modules/torch/nn/modules/linear.html#Linear):

self.weight = Parameter(torch.Tensor(out_features, in_features))

only to then compute the linear transform as follows:

def forward(self, input):
    return F.linear(input, self.weight, self.bias)

which points to (http://pytorch.org/docs/master/_modules/torch/nn/functional.html#linear):

output = input.matmul(weight.t())

Can’t this just be avoided by respecting the order in which the dimensions were given, and doing the matrix multiply without the transpose?

def __init__(self, in_features, out_features, bias=True):
    ...
    self.weight = Parameter(torch.Tensor(in_features, out_features))

then just do:

input.matmul(weight)

so that we avoid having to rotate the data? Maybe the rotation doesn’t actually move data in hardware the way I imagine, but it still seems really unnecessary.
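For concreteness, here is a minimal pure-Python sketch of what I mean (toy numbers, no PyTorch): the two storage layouts describe the same computation, one via the transpose and one directly.

```python
# Toy example: batch of 2 inputs, in_features=3, out_features=2.
x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]

# PyTorch layout: weight stored as (out_features, in_features).
w_out_in = [[0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6]]

# Proposed layout: the same weights stored as (in_features, out_features).
w_in_out = [list(col) for col in zip(*w_out_in)]  # transpose

def matmul(a, b):
    """Naive (rows of a) x (columns of b) matrix multiply."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

# input.matmul(weight.t()) with the (out, in) layout ...
y1 = matmul(x, [list(col) for col in zip(*w_out_in)])
# ... equals input.matmul(weight) with the (in, out) layout.
y2 = matmul(x, w_in_out)
assert y1 == y2
```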

Besides, even if it isn’t inefficient, it’s really weird to me that the data is represented as row vectors (i.e. one row is a data point, so the rows span the space of data points in the original raw dimension), but the weight matrix is stored as D_out x D_in, which seems to imply that D_out is the target dimension we land in. It seems odd to start thinking in terms of row vectors and then suddenly switch to column vectors. Why was this done?

Plus, when one inspects linear.weight, it was surprising to me to discover that the shape of the parameters was switched from what I initially wrote when I created my linear layer. Maybe it’s just me, but it seems super odd and confusing.


I was also thinking about this, and found this issue:

From what I understand, transposing in the forward pass has no overhead, but the backward pass will be less efficient if we compute

input.matmul(weight)

Why is it less efficient in the backward pass?
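For reference, writing out the gradients involved: with PyTorch’s layout y = x @ W.t(), the backward pass needs grad_W = grad_y.t() @ x and grad_x = grad_y @ W; with the proposed y = x @ W layout it would be grad_W = x.t() @ grad_y and grad_x = grad_y @ W.t(), so a transposed read shows up somewhere either way. A quick sketch (toy shapes, made-up values) checking the first pair against autograd:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, requires_grad=True)   # (batch, in_features)
w = torch.randn(2, 3, requires_grad=True)   # (out_features, in_features)

y = x.matmul(w.t())                         # same as F.linear(x, w)
grad_y = torch.randn_like(y)
y.backward(grad_y)

# Manual gradients for y = x @ W^T:
manual_grad_w = grad_y.t().matmul(x)        # dL/dW, shape (out, in)
manual_grad_x = grad_y.matmul(w)            # dL/dx, shape (batch, in)

assert torch.allclose(w.grad, manual_grad_w)
assert torch.allclose(x.grad, manual_grad_x)
```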

I’m afraid I cannot answer the question, since I don’t know the details of the cuDNN implementation. Apparently the RNN implementation, which uses the same matrix multiple times, exploits pre-transposing for a performance improvement.

https://devblogs.nvidia.com/parallelforall/optimizing-recurrent-neural-networks-cudnn-5/

When performing a GEMM the standard BLAS API allows you to transpose either of the two input matrices. Some of the four combinations of transpose/not-transposed run slightly faster or slower than others. Depending on the way that the equations are mapped to the computation, a slower version of the GEMM may be used. By performing a transpose operation up-front on the weight matrix, each step can be made slightly faster. This comes at the cost of the transpose, but that is fairly cheap, so if the transposed matrix is to be used for more than a few iterations it is often worth it.

It would be nice if someone could elaborate on this for more general models, including feed-forward networks.


Transposition is free for gemm calls, because BLAS libraries (which implement general matrix multiply, gemm) support both row-major and column-major matrices, as well as transpositions.

So it’s okay to have that transpose call, it’s practically a free operation.
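You can see this in PyTorch itself: .t() returns a view that shares the same storage and merely swaps the strides, so no data is moved at the call site.

```python
import torch

w = torch.randn(3, 5)
wt = w.t()

# Same underlying memory, no copy:
assert wt.data_ptr() == w.data_ptr()

# Only the metadata (shape and strides) is swapped:
assert w.shape == (3, 5) and wt.shape == (5, 3)
assert w.stride() == (5, 1) and wt.stride() == (1, 5)
```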


Adding to smth’s response, storing the second matrix of a matrix multiplication in transposed form may even increase efficiency. This is because the multiplication routine can access the memory in a more contiguous way, leading to fewer cache misses. See, e.g.,
https://stackoverflow.com/questions/18796801/increasing-the-data-locality-in-matrix-multiplication

Storing the second matrix in transposed form can easily lead to a ~5x speedup in a naive matrix multiplication implementation. The effect will be much smaller in pytorch because the underlying matrix multiplication routine is certainly more clever than the one in the above link.
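A pure-Python illustration of the access pattern (correctness only, no timings, since the actual speedup is machine-dependent): with row-major storage, the naive inner loop reads b[k][j] down a column, while the transposed version reads bt[j][k] along a contiguous row.

```python
def matmul_naive(a, b):
    """Inner loop reads b[k][j]: strides down a column of b (poor locality)."""
    n, m, p = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(m))
             for j in range(p)] for i in range(n)]

def matmul_transposed(a, bt):
    """bt is b pre-transposed; inner loop reads bt[j][k] along one row
    (contiguous in row-major storage, so fewer cache misses at scale)."""
    n, m, p = len(a), len(bt[0]), len(bt)
    return [[sum(a[i][k] * bt[j][k] for k in range(m))
             for j in range(p)] for i in range(n)]

a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
bt = [list(col) for col in zip(*b)]

assert matmul_naive(a, b) == matmul_transposed(a, bt)  # [[19, 22], [43, 50]]
```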


This is a very good explanation: by using the transpose, we get fewer cache misses. It would be great if anyone could further analyze whether the transpose has any influence on the backward pass.