Profiling torch.compile CUDA code

Given 3 tensors A, B, C of shape (N, N) on a CUDA device, I have a simple function which computes
A * B + C
To understand torch.compile, I have tried to run and profile this code with and without torch.compile.
Without torch.compile, the code runs as expected, i.e. there are 2 kernel launches: one for the element-wise multiplication (temp = A * B) and another for the element-wise addition (temp + C):

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           test compile         0.00%       0.000us         0.00%       0.000us       0.000us      46.816ms     12199.90%      46.816ms      46.816ms             1  
                                           test compile         1.30%       1.350ms        99.98%     103.550ms     103.550ms       0.000us         0.00%     383.738us     383.738us             1  
                                              aten::add         1.67%       1.734ms        42.17%      43.679ms      43.679ms     194.109us        50.58%     194.109us     194.109us             1  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     194.109us        50.58%     194.109us     194.109us             1  
                                              aten::mul        16.62%      17.213ms        56.50%      58.522ms      58.522ms     189.629us        49.42%     189.629us     189.629us             1  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     189.629us        49.42%     189.629us     189.629us             1  
                                       cudaLaunchKernel        80.38%      83.253ms        80.38%      83.253ms      41.626ms       0.000us         0.00%       0.000us       0.000us             2  
                                  cudaDeviceSynchronize         0.02%      19.208us         0.02%      19.208us      19.208us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 103.570ms
Self CUDA time total: 383.738us
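
For reference, the eager numbers above come from essentially the same script as the repro at the bottom of the post, just profiling the plain (uncompiled) function; roughly:

# Sketch of the eager-mode run (no torch.compile): two element-wise kernels are launched.
import torch
import torch.profiler

n = 16000
A = torch.randn((n, n), device='cuda')
B = torch.randn((n, n), device='cuda')
C = torch.randn((n, n), device='cuda')

with torch.profiler.profile() as prof:
    with torch.profiler.record_function("test compile"):
        D = A * B + C  # aten::mul launches one kernel, aten::add a second one
    torch.cuda.synchronize()  # shows up as cudaDeviceSynchronize in the table

print(prof.key_averages().table(sort_by="cuda_time_total"))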

With torch.compile, when I compile and run for the first time, I get these statistics:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
    compile_fx.<locals>.fw_compiler_base (dynamo_timed)         0.00%       0.000us         0.00%       0.000us       0.000us     133.232ms    164244.17%     133.232ms     133.232ms             1  
                                           test compile         0.87%      20.587ms       100.00%        2.357s        2.357s       0.000us         0.00%      81.118us      81.118us             1  
                  _compile.compile_inner (dynamo_timed)         5.44%     128.167ms        99.11%        2.336s        2.336s       0.000us         0.00%      81.118us      81.118us             1  
          OutputGraph.call_user_compiler (dynamo_timed)        80.97%        1.908s        93.03%        2.192s        2.192s       0.000us         0.00%      81.118us      81.118us             1  
          create_aot_dispatcher_function (dynamo_timed)         0.84%      19.879ms        12.06%     284.118ms     284.118ms       0.000us         0.00%      81.118us      81.118us             1  
                                            aten::copy_         0.66%      15.654ms         0.71%      16.787ms     258.256us      81.118us       100.00%      81.118us       1.248us            65  
    compile_fx.<locals>.fw_compiler_base (dynamo_timed)         6.84%     161.163ms         9.72%     229.105ms     229.105ms       0.000us         0.00%      81.118us      81.118us             1  
                                            aten::clone         0.05%       1.138ms         0.56%      13.224ms     322.525us       0.000us         0.00%      64.255us       1.567us            41  
                                       aten::lift_fresh         0.79%      18.681ms         0.93%      21.862ms     295.427us       0.000us         0.00%      64.255us       0.868us            74  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      64.255us        79.21%      64.255us       2.677us            24  
                                               aten::to         0.24%       5.714ms         0.67%      15.848ms     417.043us       0.000us         0.00%      16.863us       0.444us            38  
                                         aten::_to_copy         0.17%       4.092ms         0.43%      10.133ms     422.215us       0.000us         0.00%      16.863us       0.703us            24  
                       Memcpy HtoD (Pageable -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      16.863us        20.79%      16.863us       0.703us            24  
                               TorchDynamo Cache Lookup         0.00%       3.039us         0.00%       3.039us       0.760us       0.000us         0.00%       0.000us       0.000us             4  
                                  cudaStreamIsCapturing         0.00%      25.294us         0.00%      25.294us       3.613us       0.000us         0.00%       0.000us       0.000us             7  
                                    aten::empty_strided         0.33%       7.718ms         0.33%       7.718ms      26.432us       0.000us         0.00%       0.000us       0.000us           292  
                                           aten::detach         0.02%     582.685us         0.31%       7.316ms      25.316us       0.000us         0.00%       0.000us       0.000us           289  
                                                 detach         0.14%       3.265ms         0.26%       6.163ms     106.254us       0.000us         0.00%       0.000us       0.000us            58  
                                            aten::empty         1.07%      25.186ms         1.17%      27.655ms     119.720us       0.000us         0.00%       0.000us       0.000us           231  
                                             cudaMalloc         0.01%     253.339us         0.01%     253.339us     253.339us       0.000us         0.00%       0.000us       0.000us             1  
                                              aten::mul         0.25%       5.918ms         0.72%      16.917ms       3.383ms       0.000us         0.00%       0.000us       0.000us             5  
                                             prims::mul         0.14%       3.220ms         0.22%       5.136ms       5.136ms       0.000us         0.00%       0.000us       0.000us             1  
                                   aten::empty_permuted         0.08%       1.797ms         0.08%       1.995ms     665.058us       0.000us         0.00%       0.000us       0.000us             3  
                                       aten::as_strided         0.06%       1.371ms         0.06%       1.459ms     182.417us       0.000us         0.00%       0.000us       0.000us             8  
                                              aten::add         0.14%       3.378ms         0.18%       4.324ms     864.771us       0.000us         0.00%       0.000us       0.000us             5  
                                             prims::add         0.01%     170.167us         0.01%     212.438us     212.438us       0.000us         0.00%       0.000us       0.000us             1  
                                          aten::detach_         0.03%     725.800us         0.03%     760.008us      20.541us       0.000us         0.00%       0.000us       0.000us            37  
                                                detach_         0.00%      34.208us         0.00%      34.208us       0.925us       0.000us         0.00%       0.000us       0.000us            37  
                                            aten::alias         0.01%     331.017us         0.02%     570.771us     570.771us       0.000us         0.00%       0.000us       0.000us             1  
                                             aten::view         0.01%     150.991us         0.01%     239.754us     239.754us       0.000us         0.00%       0.000us       0.000us             1  
                                         prims::view_of         0.00%      74.287us         0.00%      88.763us      88.763us       0.000us         0.00%       0.000us       0.000us             1  
                                             aten::set_         0.09%       2.227ms         0.09%       2.227ms     222.736us       0.000us         0.00%       0.000us       0.000us            10  
                                        cudaMemcpyAsync         0.04%     957.462us         0.04%     957.462us      19.947us       0.000us         0.00%       0.000us       0.000us            48  
                                  cudaStreamSynchronize         0.01%     174.836us         0.01%     174.836us       7.285us       0.000us         0.00%       0.000us       0.000us            24  
                                         aten::randperm         0.20%       4.798ms         0.47%      11.176ms       2.235ms       0.000us         0.00%       0.000us       0.000us             5  
                                            aten::slice         0.25%       5.887ms         0.40%       9.515ms       2.379ms       0.000us         0.00%       0.000us       0.000us             4  
                                        aten::index_add         0.01%     297.864us         0.07%       1.638ms       1.638ms       0.000us         0.00%       0.000us       0.000us             1  
                                        aten::index_put         0.05%       1.204ms         0.10%       2.468ms     822.563us       0.000us         0.00%       0.000us       0.000us             3  
                                       aten::empty_like         0.00%      99.503us         0.01%     145.478us      48.493us       0.000us         0.00%       0.000us       0.000us             3  
                                            aten::index         0.07%       1.663ms         0.22%       5.269ms       2.634ms       0.000us         0.00%       0.000us       0.000us             2  
                                        aten::new_empty         0.04%     844.291us         0.04%     941.363us     470.681us       0.000us         0.00%       0.000us       0.000us             2  
                        compile_fx_inner (dynamo_timed)         0.04%     836.516us         0.04%     836.516us     836.516us       0.000us         0.00%       0.000us       0.000us             1  
                                  Torch-Compiled Region         0.01%     262.690us         0.01%     263.364us     263.364us       0.000us         0.00%       0.000us       0.000us             1  
                                  cudaDeviceSynchronize         0.00%      14.106us         0.00%      14.106us      14.106us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.357s
Self CUDA time total: 81.118us

Now, if I create new matrices A, B, C and run the profiling again, I get:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           test compile         0.60%       1.547ms        99.99%     259.521ms     259.521ms             1  
                               TorchDynamo Cache Lookup         0.01%      23.106us         0.01%      23.106us      23.106us             1  
                  _compile.compile_inner (dynamo_timed)        66.16%     171.720ms        99.36%     257.887ms     257.887ms             1  
                                  cudaStreamIsCapturing         0.01%      25.576us         0.01%      25.576us       4.263us             6  
                                    aten::empty_strided         4.11%      10.656ms         4.11%      10.656ms     313.399us            34  
                                           aten::detach         0.08%     214.091us         0.54%       1.391ms      27.824us            50  
                                                 detach         0.43%       1.110ms         0.45%       1.177ms      42.039us            28  
                                              aten::mul        10.17%      26.409ms        11.04%      28.659ms       7.165ms             4  
                                            aten::empty         0.83%       2.145ms         0.83%       2.145ms     102.150us            21  
                                              aten::add         7.09%      18.409ms         7.98%      20.708ms       5.177ms             4  
          OutputGraph.call_user_compiler (dynamo_timed)         1.42%       3.679ms        16.59%      43.057ms      43.057ms             1  
          create_aot_dispatcher_function (dynamo_timed)         5.77%      14.983ms        14.67%      38.071ms      38.071ms             1  
                                            aten::clone         0.08%     220.023us         0.23%     601.647us      35.391us            17  
                                            aten::copy_         0.08%     195.953us         0.08%     195.953us      11.527us            17  
                                               aten::to         0.03%      86.083us         0.03%      86.083us       6.622us            13  
                                       aten::lift_fresh         2.26%       5.855ms         2.57%       6.676ms     256.776us            26  
                                          aten::detach_         0.08%     209.552us         0.09%     221.823us      17.063us            13  
                                                detach_         0.00%      12.271us         0.00%      12.271us       0.944us            13  
    compile_fx.<locals>.fw_compiler_base (dynamo_timed)         0.44%       1.150ms         0.75%       1.950ms       1.950ms             1  
                        compile_fx_inner (dynamo_timed)         0.31%     799.775us         0.31%     799.775us     799.775us             1  
                                       aten::empty_like         0.00%       7.427us         0.01%      18.140us       9.070us             2  
                                  Torch-Compiled Region         0.02%      64.157us         0.02%      64.157us      64.157us             1  
                                  cudaDeviceSynchronize         0.01%      29.074us         0.01%      29.074us      29.074us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 259.550ms

In both cases, one thing I observe is that aten::add, aten::mul, prims::add and prims::mul all happen on the CPU, even though the matrices are on the GPU.

Am I misunderstanding something from these profiling results?

Here is the code to reproduce:

import torch
import torch.profiler

n = 16000
A = torch.randn((n, n), device='cuda')
B = torch.randn((n, n), device='cuda')
C = torch.randn((n, n), device='cuda')
D = torch.zeros((n, n), device='cuda')

def test():
    D = A * B + C  # note: D is a new local here, nothing is returned

test_c = torch.compile(test, fullgraph=True, mode="max-autotune")

with torch.profiler.profile() as prof:
    with torch.profiler.record_function("test compile"):
        test_c()

prof.export_chrome_trace("trace_compile.json")
print(prof.key_averages().table(sort_by="cuda_time_total"))
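
For what it's worth, profiling a second, warmed-up call separately keeps the one-time compilation work out of the table; continuing the script above (a sketch):

# Sketch: warm up first, then profile only the steady-state call.
test_c()                  # the first call triggers Dynamo/Inductor compilation
torch.cuda.synchronize()  # wait for any compilation-time GPU work to finish

with torch.profiler.profile() as prof2:
    with torch.profiler.record_function("test compile warm"):
        test_c()          # cache hit: reuses the compiled code as long as the guards pass
    torch.cuda.synchronize()

print(prof2.key_averages().table(sort_by="cuda_time_total"))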

Return the output, as otherwise torch.compile might be smart enough to eliminate the computation as dead code.
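
A minimal sketch of that change:

def test():
    D = A * B + C
    return D  # returning the result keeps the computation from being eliminated
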
Adding return D I see:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Torch-Compiled Region: 4/0         0.79%     143.198us        32.11%       5.841ms       5.841ms       4.745ms        38.02%      12.480ms      12.480ms             1  
                                   aten::_foreach_copy_         6.85%       1.247ms        31.08%       5.655ms       5.655ms       7.731ms        61.95%       7.731ms       7.731ms             1  
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us       7.731ms        61.95%       7.731ms     208.958us            37  
                             triton_poi_fused_add_mul_0         0.00%       0.000us         0.00%       0.000us       0.000us       4.745ms        38.02%       4.745ms       4.745ms             1  
                                            aten::fill_         0.10%      18.586us         0.15%      27.142us      13.571us       3.296us         0.03%       3.296us       1.648us             2  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.296us         0.03%       3.296us       1.648us             2  
                               TorchDynamo Cache Lookup         0.11%      19.547us         0.11%      19.547us      19.547us       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel        24.28%       4.417ms        24.28%       4.417ms     113.248us       0.000us         0.00%       0.000us       0.000us            39  
                                  cudaStreamIsCapturing         0.01%       1.533us         0.01%       1.533us       0.767us       0.000us         0.00%       0.000us       0.000us             2  
                                        cudaGraphLaunch         0.08%      14.097us         0.08%      14.097us      14.097us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 18.194ms
Self CUDA time total: 12.480ms

Thank you! Just one last quick question: is there any specific reason why PyTorch doesn’t perform kernel fusion? For instance, an optimization could be to compute A[i] * B[i] + C[i] using a fused multiply-add kernel, so only one kernel would launch.

Or is it that, regardless of optimizations, any fused operations still require custom CUDA kernels to be written? So, if PyTorch doesn’t already have a fused kernel for a particular operation, it won’t be able to optimize it.

Edit: I just observed that the profiling output shows triton_poi_fused_add_mul_0. But then, is my understanding about kernel optimization correct, i.e. that not all possible optimized kernels are created / generated by PyTorch?

triton_poi_fused_add_mul_0 is a fused kernel generated by Inductor.
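
If you want to inspect it, the generated code can be dumped, e.g. by running with the environment variable TORCH_LOGS="output_code" or its Python equivalent; a sketch, assuming a recent PyTorch 2.x:

# Sketch: print the Triton code Inductor generates (assumes the PyTorch 2.x logging API).
# Equivalent to running the script with TORCH_LOGS="output_code".
import torch
torch._logging.set_logs(output_code=True)

@torch.compile
def fma(a, b, c):
    return a * b + c

x = torch.randn(1024, 1024, device='cuda')
y = torch.randn(1024, 1024, device='cuda')
z = torch.randn(1024, 1024, device='cuda')
fma(x, y, z)  # the fused kernel (e.g. triton_poi_fused_add_mul_0) appears in the log output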

Yes, I don’t think any compiler can generate all possible optimized kernels, as that would mean code generation for DL is a solved problem.
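
For this particular pattern there is also a pre-written fused op available in eager mode, torch.addcmul, which computes input + value * tensor1 * tensor2 and should launch a single element-wise kernel; a quick sketch:

# Sketch: the same A * B + C as a single eager-mode call via the existing fused op.
import torch

A = torch.randn(1024, 1024, device='cuda')
B = torch.randn(1024, 1024, device='cuda')
C = torch.randn(1024, 1024, device='cuda')

D = torch.addcmul(C, A, B)  # C + 1 * (A * B)
print((D - (A * B + C)).abs().max())  # ~0, up to floating-point rounding differences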
