Traced function with addmm slower than Python

Hello,

I am playing with a very simple dense-layer implementation using torch.addmm, and it seems that torch.jit.trace transforms the addmm op into a sequence of mm and add ops, leading to a performance drop on CPU:

import torch
from torch.autograd import profiler
torch.set_num_threads(1)

def dense_layer(input, w, b):
    return torch.addmm(input=b, mat1=input, mat2=w)

if __name__ == '__main__':
    torch.random.manual_seed(1234)
    a = torch.randn(100000, 10)
    b = torch.randn(10, 10)
    c = torch.randn(10)

    with profiler.profile() as prof:
        for i in range(1000):
            dense_layer(a, b, c)
    print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=5))

    traced = torch.jit.trace(dense_layer, (a, b, c))
    with profiler.profile() as prof2:
        for i in range(1000):
            traced(a, b, c)
    print(prof2.key_averages().table(sort_by='cpu_time_total', row_limit=5))
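
To see the decomposition directly, the traced graph can also be printed (a quick check; the exact graph text varies across PyTorch versions):

print(traced.graph)  # shows aten::mm followed by aten::add instead of a single aten::addmm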

Here is the output of this script on an EC2 Windows instance with an Intel Xeon E5-2686 v4 @ 2.30GHz:

--------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name            Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  
--------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
addmm           99.89%           5.603s           100.00%          5.609s           5.609ms          1000             
expand          0.05%            2.927ms          0.09%            4.772ms          4.772us          1000             
as_strided      0.03%            1.845ms          0.03%            1.845ms          1.845us          1000             
--------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 5.609s

-----------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name         Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  
-----------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
add          65.54%           5.190s           65.81%           5.212s           5.212ms          1000             
mm           33.91%           2.685s           34.19%           2.708s           2.708ms          1000             
empty        0.29%            23.118ms         0.29%            23.118ms         11.559us         2000             
-----------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 7.919s

The same script run on Fedora 32 with an Intel Core i7-8700K CPU @ 3.70GHz:

--------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name            Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  
--------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
addmm           99.87%           1.863s           100.00%          1.866s           1.866ms          1000             
expand          0.06%            1.164ms          0.10%            1.895ms          1.895us          1000             
as_strided      0.04%            731.363us        0.04%            731.363us        0.731us          1000             
--------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 1.866s

-----------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name         Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  
-----------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
add          73.35%           1.430s           73.40%           1.431s           1.431ms          1000             
mm           26.50%           516.765ms        26.60%           518.530ms        518.530us        1000             
empty        0.09%            1.726ms          0.09%            1.726ms          0.863us          2000             
-----------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Self CPU time total: 1.950s

On the Fedora machine the difference is less pronounced, but it is still there. What is the reason for this transformation? At least on CPU, both mm and addmm call the same gemm, so it seems more reasonable to expand the bias vector to cover the whole output matrix and then call gemm once. Is it because of a focus on GPU? Is there any way to produce a CPU-efficient trace for such a case?

There is a secret linear function :slight_smile:

def dense_layer(input, w, b):
    return torch.ops.aten.linear(input, w, b)

More seriously, I think that whatever causes addmm to be split could be considered a bug.

The secret function also gets split in trace mode :slight_smile:
I will raise the issue then; however, this does not look unintentional.

Ah, my bad. If you script it, it works (even if you then trace the scripted function, I think).
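
Something like this should keep linear intact (a rough sketch; the exact graph output may differ between versions):

import torch

def dense_layer(input, w, b):
    # note: aten::linear computes input @ w.t() + b, so the weight
    # layout differs from the addmm version above
    return torch.ops.aten.linear(input, w, b)

scripted = torch.jit.script(dense_layer)
print(scripted.graph)  # should show a single aten::linear node

a = torch.randn(100000, 10)
b = torch.randn(10, 10)
c = torch.randn(10)
traced = torch.jit.trace(scripted, (a, b, c))  # tracing the scripted function should preserve it as well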

Thanks, it does. It seems that scripting the plain addmm version still does not work, though.

Moreover, running torch.jit.script(dense_layer) with the former implementation raises an error:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in dense_layer

      def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: number = 1.0, alpha: number = 1.0):
          return self + mat1.mm(mat2)
                        ~~~~~~~ <--- HERE

      def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

Edit: I was just using the wrong signature; calling addmm without the keyword arguments works.
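
For reference, a minimal sketch of the keyword-free call, matching the aten::addmm schema from the traceback (self, mat1, mat2):

import torch

def dense_layer(input, w, b):
    # positional arguments follow the addmm schema: (self, mat1, mat2)
    return torch.addmm(b, input, w)

scripted = torch.jit.script(dense_layer)  # scripts without the dimension error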