Speeding up backward pass

Hello,

I am seeing low GPU utilization (~20%) in my PyTorch program. I am profiling with the torch profiler and getting output like this (full table at the bottom):

Self CPU time total: 8.252s

Self CUDA time total: 547.091ms

It seems like the CPU is doing much more work (8.2 s) than the GPU (0.5 s). In particular, the CompiledFunctionBackward step seems to execute mostly on the CPU while the GPU sits idle.
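
As a rough cross-check against the ~20% utilization, the GPU time can be divided by the step time. This sketch assumes the `ProfilerStep*` "CPU total" (3.343 s over 5 steps) approximates wall-clock time and that GPU kernels never overlap. It leaves out the 8.252 s "Self CPU time total", which appears to add up time across threads (the backward pass runs on a separate autograd thread while the main thread waits), so it can exceed wall-clock time:

```python
# Figures copied from the profiler table (5 profiled steps).
STEPS = 5
step_cpu_total_ms = 3343.0    # "ProfilerStep*" CPU total, assumed ~ wall-clock time
self_cuda_total_ms = 547.091  # "Self CUDA time total" (summed GPU kernel time)

per_step_ms = step_cpu_total_ms / STEPS
gpu_per_step_ms = self_cuda_total_ms / STEPS
# Assumes kernels never overlap, so summed kernel time = GPU busy time.
gpu_busy = self_cuda_total_ms / step_cpu_total_ms

print(f"{per_step_ms:.1f} ms per step, {gpu_per_step_ms:.1f} ms of GPU work")
print(f"GPU busy ~{gpu_busy:.0%} of the time")
```

Under these assumptions the GPU is busy for roughly 110 ms of each ~670 ms step, which is in the same range as the observed utilization and suggests the GPU is mostly waiting on the host.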

Am I misinterpreting the results? Or does the backward pass really run mostly on the CPU while the GPU waits? If so, what could cause this, and how can I debug it further?

I also loaded trace.json into TensorBoard, and it reports no data-loading problems. So my current theory is that the problem lies somewhere in torch.compile(), which somehow makes the backward pass run on the CPU?

If I disable torch.compile(), the model trains more slowly, even though GPU utilization is higher (~50%). I am not sure what's going on. Any help is appreciated!
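
Faster training with lower utilization is not necessarily a contradiction: utilization only measures the fraction of each step during which the GPU is busy. If compilation (for example, kernel fusion) removes more GPU work than host-side overhead, steps get shorter while the busy fraction drops. The timings below are purely hypothetical, and the model assumes CPU and GPU work overlap perfectly:

```python
# Hypothetical per-step timings in ms: NOT measured, chosen only to illustrate.
# Simplifying assumption: host (CPU) work and GPU work overlap perfectly,
# so a step lasts as long as the slower of the two.
def step_stats(gpu_ms: float, cpu_ms: float) -> tuple[float, float]:
    step_ms = max(gpu_ms, cpu_ms)
    return step_ms, gpu_ms / step_ms  # (step time, GPU utilization)


eager = step_stats(gpu_ms=300.0, cpu_ms=600.0)     # many unfused kernels
compiled = step_stats(gpu_ms=100.0, cpu_ms=500.0)  # fused kernels, less GPU work

for label, (step_ms, util) in (("eager", eager), ("compiled", compiled)):
    print(f"{label:>8}: {step_ms:.0f} ms/step, GPU utilization {util:.0%}")
```

In this toy setting the compiled run is faster per step yet shows lower utilization, because the host-side work now dominates the step.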

                                               Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Total KFLOPs
                                       ProfilerStep*        32.62%        2.691s        40.52%        3.343s     668.682ms       0.000us         0.00%      31.261ms       6.252ms           0 B           0 B      -4.50 KB      50.29 MB             5            --
                                      backward_pass        27.43%        2.263s        27.43%        2.263s     452.678ms       0.000us         0.00%       6.145us       1.229us      -1.29 KB      -1.29 KB     -15.02 GB     -15.02 GB             5            --
                           CompiledFunctionBackward        24.53%        2.025s        26.67%        2.201s     440.103ms       0.000us         0.00%     368.115ms      73.623ms      -1.17 KB      -1.17 KB     -15.02 GB     -21.10 GB             5            --
                                   CompiledFunction         7.79%     642.705ms         8.74%     721.003ms     144.201ms       0.000us         0.00%     145.911ms      29.182ms       1.17 KB           0 B      15.09 GB      14.82 GB             5            --
                                      training_step         1.31%     108.125ms        32.35%        2.669s     533.830ms       0.000us         0.00%     124.372ms      24.874ms           0 B           0 B           0 B     -82.95 MB             5            --
                          Optimizer.step#AdamW.step         0.94%      77.833ms         1.75%     144.124ms      28.825ms       0.000us         0.00%       5.007ms       1.001ms           0 B         -40 B           0 B     -71.61 MB             5            --
                                     cuLaunchKernel         0.66%      54.710ms         0.68%      56.275ms       8.057us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          6985            --
                                           aten::mm         0.44%      36.274ms         0.65%      53.863ms      30.007us      76.160ms        13.92%      76.160ms      42.429us           0 B           0 B           0 B           0 B          1795  1023490122.880
                                   cudaLaunchKernel         0.29%      24.178ms         0.29%      24.200ms       7.401us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          3270            --
                                        aten::addmm         0.29%      24.078ms         0.42%      34.796ms      41.424us      34.870ms         6.37%      34.870ms      41.512us           0 B           0 B           0 B      -2.00 MB           840  489941548.160
                                       forward_pass         0.25%      20.316ms         9.26%     764.147ms     152.829ms       0.000us         0.00%     146.693ms      29.339ms       1.29 KB         120 B      15.09 GB     -13.67 MB             5            --
                                aten::empty_strided         0.16%      12.811ms         0.16%      12.811ms       4.093us       0.000us         0.00%       0.000us       0.000us           0 B           0 B     122.96 MB     122.96 MB          3130            --
autograd::engine::evaluate_function: torch::autograd…         0.15%      12.275ms         0.36%      29.356ms      10.071us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          2915            --
                    torch::autograd::AccumulateGrad         0.12%      10.224ms         0.21%      17.081ms       5.860us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          2915            --
                                         aten::item         0.11%       9.346ms         0.13%      10.506ms       1.800us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          5835            --
                                  aten::_foreach_add_         0.11%       8.706ms         0.15%      12.359ms     617.965us     383.295us         0.07%     383.295us      19.165us           0 B           0 B           0 B           0 B            20            --
                                 aten::_foreach_norm         0.10%       8.004ms         0.10%       8.453ms       1.691ms     503.327us         0.09%     509.633us     101.927us           0 B           0 B       1.42 MB       1.40 MB             5            --
                         Torch-Compiled Region: 0/0         0.09%       7.810ms         8.84%     729.778ms     145.956ms       0.000us         0.00%     145.911ms      29.182ms       1.17 KB           0 B      15.09 GB           0 B             5            --
                                     mds_batch_load         0.08%       6.923ms         0.16%      13.035ms       2.607ms       0.000us         0.00%       2.742ms     548.355us           0 B           0 B      31.05 MB           0 B             5            --
                                 aten::_foreach_sqrt         0.08%       6.862ms         0.23%      18.725ms       1.873ms     685.636us         0.13%     685.636us      68.564us           0 B           0 B      71.61 MB           0 B            10            --
                            triton_per_fused_sum_13         0.08%       6.483ms         0.12%      10.010ms      19.250us     885.806us         0.16%     885.806us       1.703us           0 B           0 B           0 B           0 B           520            --
                                  aten::_foreach_mul_         0.08%       6.477ms         0.09%       7.499ms     374.939us       1.305ms         0.24%       1.305ms      65.267us           0 B           0 B           0 B           0 B            20            --
                                    aten::unsqueeze         0.08%       6.299ms         0.10%       8.462ms       2.835us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          2985            --
                Optimizer.zero_grad#AdamW.zero_grad         0.06%       5.307ms         0.06%       5.307ms       1.061ms       0.000us         0.00%       0.000us       0.000us           0 B           0 B     -71.61 MB     -71.61 MB             5            --
                              aten::_foreach_addcdiv_         0.06%       5.151ms         0.07%       5.917ms     591.653us       1.062ms         0.19%       1.062ms     106.224us           0 B           0 B           0 B           0 B            10            --
                              aten::_foreach_addcmul_         0.06%       4.986ms         0.07%       5.772ms     577.246us     756.864us         0.14%     756.864us      75.686us           0 B           0 B           0 B           0 B            10            --
                           TorchDynamo Cache Lookup         0.06%       4.587ms         0.06%       4.587ms     917.399us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B             5            --
                                  aten::_foreach_div_         0.05%       4.390ms         0.06%       5.014ms     501.430us     665.313us         0.12%     665.313us      66.531us           0 B           0 B           0 B           0 B            10            --
                                        aten::empty         0.05%       4.064ms         0.05%       4.066ms       5.020us       0.000us         0.00%       0.000us       0.000us       1.21 KB       1.21 KB      14.29 GB      14.29 GB           810            --
                                        aten::stack         0.05%       3.857ms         0.15%      12.023ms       2.405ms       0.000us         0.00%      50.817us      10.163us           0 B           0 B      12.50 KB           0 B             5            --
                                       aten::detach         0.05%       3.815ms         0.09%       7.025ms       2.385us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          2945            --
                                   aten::as_strided         0.04%       3.440ms         0.04%       3.440ms       0.784us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          4390            --
                                          aten::add_         0.04%       3.302ms         0.04%       3.461ms       1.179us      29.536us         0.01%      29.536us       0.010us           0 B           0 B           0 B           0 B          2935            --
                                    cudaMemcpyAsync         0.04%       3.283ms         0.04%       3.283ms       9.247us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B           355            --
                                             detach         0.04%       3.210ms         0.04%       3.210ms       1.090us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          2945            --
                 aten::_efficient_attention_backward         0.04%       3.180ms         0.16%      13.116ms     174.877us     172.808ms        31.59%     176.223ms       2.350ms           0 B           0 B       6.08 GB      -8.15 GB            75            --
                            triton_red_fused_sum_12         0.04%       2.949ms         0.05%       4.511ms      20.506us       2.312ms         0.42%       2.312ms      10.509us           0 B           0 B           0 B           0 B           220            --
                                    aten::transpose         0.03%       2.580ms         0.04%       3.439ms       3.527us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B           975            --
       aten::_scaled_dot_product_efficient_attention         0.03%       2.558ms         0.09%       7.681ms     102.413us       0.000us         0.00%      69.165ms     922.196us       1.17 KB           0 B     275.80 MB           0 B            75            --
                            triton_poi_fused_sum_39         0.03%       2.378ms         0.04%       3.550ms      21.512us     225.279us         0.04%     225.279us       1.365us           0 B           0 B           0 B           0 B           165            --
                                  aten::result_type         0.03%       2.279ms         0.03%       2.279ms       0.146us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B         15595            --
aten::_scaled_dot_product_efficient_attention_backwa…         0.03%       2.220ms         0.21%      17.257ms     230.090us       0.000us         0.00%     176.223ms       2.350ms           0 B           0 B       6.08 GB           0 B            75            --
                        triton_per_fused_cat_sum_46         0.03%       2.074ms         0.04%       3.215ms      18.910us     274.334us         0.05%     274.334us       1.614us           0 B           0 B           0 B           0 B           170            --
                                          aten::bmm         0.02%       2.019ms         0.05%       3.777ms      41.965us       1.839ms         0.34%       1.905ms      21.168us           0 B           0 B           0 B      -1.82 MB            90    217128.960
                                          aten::mul         0.02%       2.005ms         0.04%       3.353ms      24.837us       1.177ms         0.22%       1.177ms       8.721us           0 B           0 B     211.49 MB     211.49 MB           135     55436.845
                        triton_per_fused_cat_sum_38         0.02%       1.977ms         0.04%       3.028ms      18.926us     264.735us         0.05%     264.735us       1.655us           0 B           0 B           0 B           0 B           160            --
                            triton_per_fused_sum_33         0.02%       1.939ms         0.04%       2.961ms      21.149us     301.180us         0.06%     301.180us       2.151us           0 B           0 B           0 B           0 B           140            --
autograd::engine::evaluate_function: CompiledFunctio…         0.02%       1.921ms        26.69%        2.202s     440.487ms       0.000us         0.00%     368.115ms      73.623ms      -1.17 KB           0 B     -15.02 GB    -157.50 KB             5            --
                            triton_poi_fused_mul_80         0.02%       1.868ms         0.03%       2.739ms      26.087us     153.092us         0.03%     153.092us       1.458us           0 B           0 B           0 B           0 B           105            --
                                          aten::sum         0.02%       1.753ms         0.03%       2.590ms      27.258us     983.075us         0.18%     983.075us      10.348us           0 B           0 B      13.14 MB      13.14 MB            95            --
                                         aten::copy_         0.02%       1.667ms         0.05%       3.736ms      13.343us       2.988ms         0.55%       2.988ms      10.672us           0 B           0 B           0 B           0 B           280            --
                            Activity Buffer Request         0.02%       1.560ms         0.02%       1.560ms       1.560ms       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B             1            --
     triton_per_fused_native_layer_norm_backward_83         0.02%       1.543ms         0.03%       2.382ms      19.853us     193.440us         0.04%     193.440us       1.612us           0 B           0 B           0 B           0 B           120            --
                  aten::_efficient_attention_forward         0.02%       1.499ms         0.05%       3.852ms      51.353us      69.165ms        12.64%      69.165ms     922.196us       1.17 KB           0 B     275.80 MB           0 B            75            --
                                           aten::to         0.02%       1.493ms         0.07%       5.530ms       1.686us       0.000us         0.00%       2.742ms       0.836us           0 B           0 B      31.05 MB           0 B          3280            --
                        triton_red_fused_mul_sum_45         0.02%       1.329ms         0.02%       1.950ms      22.938us     401.563us         0.07%     401.563us       4.724us           0 B           0 B           0 B           0 B            85            --
                                          aten::div         0.02%       1.297ms         0.03%       2.267ms      26.666us     134.430us         0.02%     134.430us       1.582us           0 B           0 B       2.64 MB       2.64 MB            85            --
                        triton_red_fused_mul_sum_23         0.01%       1.218ms         0.02%       1.832ms      21.547us     209.572us         0.04%     209.572us       2.466us           0 B           0 B           0 B           0 B            85            --
     triton_red_fused_native_layer_norm_backward_20         0.01%       1.186ms         0.02%       1.712ms      24.464us       1.157ms         0.21%       1.157ms      16.525us           0 B           0 B           0 B           0 B            70            --
                           aten::_local_scalar_dense         0.01%       1.160ms         0.01%       1.160ms       0.199us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          5835            --
                            triton_per_fused_sum_44         0.01%       1.156ms         0.02%       1.765ms      20.763us     139.363us         0.03%     139.363us       1.640us           0 B           0 B           0 B           0 B            85            --
                                          aten::cat         0.01%       1.091ms         0.02%       1.424ms      56.953us     106.140us         0.02%     106.140us       4.246us           0 B           0 B     647.50 KB     647.50 KB            25            --
                        triton_red_fused_mul_sum_37         0.01%       1.085ms         0.02%       1.588ms      22.683us       1.154ms         0.21%       1.154ms      16.486us           0 B           0 B           0 B           0 B            70            --
                                          aten::sub         0.01%       1.031ms         0.02%       1.728ms      23.043us     137.470us         0.03%     137.470us       1.833us           0 B           0 B       5.97 MB       5.97 MB            75            --
                                      aten::_to_copy         0.01%       1.011ms         0.05%       4.036ms      21.819us       0.000us         0.00%       2.742ms      14.820us           0 B           0 B      31.05 MB           0 B           185            --
                            triton_red_fused_sum_84         0.01%     985.793us         0.02%       1.504ms      20.047us     192.419us         0.04%     192.419us       2.566us           0 B           0 B           0 B           0 B            75            --
                            triton_poi_fused_sum_50         0.01%     862.559us         0.02%       1.314ms      21.896us      79.427us         0.01%      79.427us       1.324us           0 B           0 B           0 B           0 B            60            --
                             cudaDeviceGetAttribute         0.01%     817.559us         0.01%     817.559us       0.516us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B          1585            --
     triton_red_fused_gelu_mul_native_layer_norm_99         0.01%     806.550us         0.01%       1.136ms      32.470us       1.882ms         0.34%       1.882ms      53.776us           0 B           0 B           0 B           0 B            35            --
                                  Pregraph bytecode         0.01%     791.520us         0.01%     791.520us     158.304us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B             5            --
                                      TopkBackward0         0.01%     777.830us         0.01%       1.141ms     228.184us       0.000us         0.00%      30.911us       6.182us           0 B           0 B     157.50 KB           0 B             5            --
                            triton_per_fused_sum_28         0.01%     758.588us         0.01%       1.142ms      22.846us     137.507us         0.03%     137.507us       2.750us           0 B           0 B           0 B           0 B            50            --
                           triton_red_fused_sum_120         0.01%     740.045us         0.01%       1.126ms      22.530us      93.279us         0.02%      93.279us       1.866us           0 B           0 B           0 B           0 B            50            --
                            triton_red_fused_sum_27         0.01%     736.207us         0.01%       1.102ms      22.039us       1.415ms         0.26%       1.415ms      28.306us           0 B           0 B           0 B           0 B            50            --
                        triton_per_fused_mul_sum_82         0.01%     727.365us         0.01%       1.143ms      22.865us      79.105us         0.01%      79.105us       1.582us           0 B           0 B           0 B           0 B            50            --
                            triton_poi_fused_cat_26         0.01%     711.025us         0.01%       1.066ms      21.326us       4.508ms         0.82%       4.508ms      90.153us           0 B           0 B           0 B           0 B            50            --
                        triton_red_fused_mul_sum_40         0.01%     686.598us         0.01%     963.091us      27.517us       5.800ms         1.06%       5.800ms     165.704us           0 B           0 B           0 B           0 B            35            --
                                         aten::fill_         0.01%     680.747us         0.02%       1.601ms      15.246us       1.338ms         0.24%       1.338ms      12.746us           0 B           0 B           0 B           0 B           105            --
                               aten::scatter_reduce_         0.01%     679.746us         0.02%       1.962ms      78.468us     696.224us         0.13%     869.281us      34.771us           0 B           0 B           0 B     -19.92 MB            25            --
                                    cudaMemsetAsync         0.01%     669.197us         0.01%     669.197us       1.263us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B           530            --
                                 aten::_foreach_lerp_         0.01%     668.381us         0.01%     980.802us      98.080us     719.876us         0.13%     719.876us      71.988us           0 B           0 B           0 B           0 B            10            --
                          triton_poi_fused_clone_42         0.01%     650.505us         0.01%     945.943us      27.027us     152.993us         0.03%     152.993us       4.371us           0 B           0 B           0 B           0 B            35            --
                           triton_per_fused_sum_151         0.01%     628.179us         0.01%     973.691us      19.474us      78.501us         0.01%      78.501us       1.570us           0 B           0 B           0 B           0 B            50            --
                                       aten::expand         0.01%     613.664us         0.01%     696.666us       8.196us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B            85            --
     triton_red_fused_native_layer_norm_backward_81         0.01%     610.006us         0.01%     893.081us      25.517us     145.824us         0.03%     145.824us       4.166us           0 B           0 B           0 B           0 B            35            --
     triton_per_fused_gelu_mul_native_layer_norm_94         0.01%     591.465us         0.01%     838.442us      33.538us      93.215us         0.02%      93.215us       3.729us           0 B           0 B           0 B           0 B            25            --
                          triton_poi_fused_clone_91         0.01%     575.625us         0.01%     819.172us      32.767us     143.010us         0.03%     143.010us       5.720us           0 B           0 B           0 B           0 B            25            --
                    triton_per_fused_exp_mul_sum_41         0.01%     569.421us         0.01%     857.318us      24.495us      77.697us         0.01%      77.697us       2.220us           0 B           0 B           0 B           0 B            35            --
                                          aten::add         0.01%     553.762us         0.01%     876.529us      25.044us      52.926us         0.01%      52.926us       1.512us           0 B           0 B     792.50 KB     792.50 KB            35       200.010
                                         aten::mean         0.01%     537.571us         0.01%     802.341us      26.745us      99.393us         0.02%      99.393us       3.313us           0 B           0 B      15.00 KB      14.00 KB            30            --
      cudaOccupancyMaxActiveBlocksPerMultiprocessor         0.01%     520.772us         0.01%     520.772us       1.447us       0.000us         0.00%       0.000us       0.000us           0 B           0 B           0 B           0 B           360            --
                            triton_red_fused_sum_43         0.01%     520.472us         0.01%     781.433us      22.327us     126.079us         0.02%     126.079us       3.602us           0 B           0 B           0 B           0 B            35            --
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
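
One way to read the table is to compare, per row, the average CPU time spent on an op with the average GPU time it runs for. For many of the small `triton_*` kernels the CPU side takes about 20 us per call while the kernel itself runs for roughly 1 to 2 us, which would point to a launch-bound (CPU-bound) step rather than a backward pass that executes on the CPU. A minimal sketch that parses a few rows copied from the table above; the parsing assumes this exact column layout and op names without spaces, and the `ratio > 2` threshold is an arbitrary heuristic:

```python
import re

# A few rows copied verbatim from the profiler table above.
TABLE = """\
                            triton_per_fused_sum_13         0.08%       6.483ms         0.12%      10.010ms      19.250us     885.806us         0.16%     885.806us       1.703us           0 B           0 B           0 B           0 B           520            --
                            triton_poi_fused_sum_39         0.03%       2.378ms         0.04%       3.550ms      21.512us     225.279us         0.04%     225.279us       1.365us           0 B           0 B           0 B           0 B           165            --
                aten::_efficient_attention_backward         0.04%       3.180ms         0.16%      13.116ms     174.877us     172.808ms        31.59%     176.223ms       2.350ms           0 B           0 B       6.08 GB      -8.15 GB            75            --
"""

# Matches time values such as "19.250us", "2.350ms" or "2.691s" (not "0.08%" or "6.08 GB").
TIME = re.compile(r"(?<![\w.])(\d+(?:\.\d+)?)(us|ms|s)(?!\w)")
TO_US = {"us": 1.0, "ms": 1e3, "s": 1e6}


def parse_row(line: str) -> tuple[str, float, float, int]:
    """Return (name, cpu_avg_us, cuda_avg_us, calls) for one row of the table."""
    # Time columns in order: Self CPU, CPU total, CPU time avg,
    # Self CUDA, CUDA total, CUDA time avg.
    times = [float(value) * TO_US[unit] for value, unit in TIME.findall(line)]
    tokens = line.split()
    name = tokens[0]         # assumes the op name contains no spaces
    calls = int(tokens[-2])  # "# of Calls" is the second-to-last column
    return name, times[2], times[5], calls


for line in TABLE.splitlines():
    name, cpu_us, gpu_us, calls = parse_row(line)
    ratio = cpu_us / gpu_us
    # Arbitrary heuristic: host time per call well above kernel time.
    hint = "likely launch-bound" if ratio > 2 else ""
    print(f"{name:36s} {calls:4d} calls  cpu {cpu_us:7.1f} us  gpu {gpu_us:7.1f} us  {hint}")
```

For the two Triton rows the host side costs roughly ten times the kernel runtime, while the attention backward kernel is the opposite. If most of the step looks like the former, `torch.compile(model, mode="reduce-overhead")`, which uses CUDA graphs to cut per-launch CPU cost, may be worth trying.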

Not sure if you have already considered this, but note that the first run, and any run where the shape of the model's inputs or outputs changes, triggers tracing and recompilation, which can be very slow (slower than an eager run). Once the shapes stop changing, subsequent runs should be consistently faster.
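
The effect can be illustrated with a toy model: a cache keyed on input shape, where every unseen shape pays a one-time "compile" cost. This is a hypothetical sketch of the idea only (`ToyShapeCache` and `fake_compile` are made-up names); torch.compile actually guards on more than shapes, and in recent versions it may switch to a dynamic-shape graph after the first shape change instead of specializing on every new shape:

```python
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]
Compiled = Callable[[list], float]


class ToyShapeCache:
    """Hypothetical stand-in for a compiler cache that specializes on input shape."""

    def __init__(self, compile_fn: Callable[[Shape], Compiled]):
        self.compile_fn = compile_fn
        self.cache: Dict[Shape, Compiled] = {}
        self.compiles = 0

    def __call__(self, batch: list) -> float:
        shape = (len(batch),)
        if shape not in self.cache:  # unseen shape: pay the compile cost once
            self.compiles += 1
            self.cache[shape] = self.compile_fn(shape)
        return self.cache[shape](batch)  # seen shape: reuse the compiled version


def fake_compile(shape: Shape) -> Compiled:
    # Stand-in for tracing and code generation; here it just returns a summing function.
    return lambda batch: float(sum(batch))


step = ToyShapeCache(fake_compile)
for batch in ([1, 2, 3], [4, 5, 6], [7, 8], [1, 2, 3]):
    step(batch)
print(step.compiles)  # one compile per distinct shape: (3,) and (2,)
```

If the shape varies only because the last batch is smaller, passing `drop_last=True` to the `DataLoader` (or padding batches to a fixed size) is one common way to avoid the extra compilation.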

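Another factor that can affect performance is the use of Python for loops and if/else statements in the compiled code: branching on tensor values causes graph breaks, and loops are unrolled during tracing, which can make compilation slower.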