How to figure out why my backward pass is so slow?

I am solving an optimization problem with PyTorch, and the forward pass is roughly 20-40 times faster than the backward pass. I would like to know the best way to profile just the loss.backward() call.
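CUDA kernels are launched asynchronously. Timing forward and backward with a plain wall clock can therefore charge forward-pass GPU work to whichever later call happens to wait for the GPU. Here is a minimal sketch of a fairer comparison. It uses a hypothetical toy model and loss, and time_forward_backward, the layer sizes, and the loss are illustrative, not from the original code:

```python
import time

import torch


def sync(device: torch.device) -> None:
    # CUDA kernels are queued asynchronously; block until they finish so the
    # timer measures completed GPU work, not just kernel launches.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def time_forward_backward(model, inputs, loss_fn, device):
    """Return (forward_seconds, backward_seconds) for one step."""
    sync(device)
    start = time.perf_counter()
    loss = loss_fn(model(inputs))
    sync(device)
    mid = time.perf_counter()
    loss.backward()
    sync(device)
    end = time.perf_counter()
    return mid - start, end - mid


def loss_fn(out: torch.Tensor) -> torch.Tensor:
    # Hypothetical loss, only to make the sketch runnable.
    return out.pow(2).mean()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hypothetical toy model standing in for the real one.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 1)
).to(device)
x = torch.randn(1024, 256, device=device)

time_forward_backward(model, x, loss_fn, device)  # warm-up run, result discarded
fwd, bwd = time_forward_backward(model, x, loss_fn, device)
print(f"forward: {fwd * 1e3:.2f} ms, backward: {bwd * 1e3:.2f} ms")
```

The first call is discarded as a warm-up because it can include one-off costs such as CUDA context and library initialization.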

I can do something like this:

        with torch.autograd.profiler.profile(use_cuda=True) as prof:
            loss.backward()

But it only shows low-level functions, and it’s unclear where the actual bottleneck is. Ideally, I would like something like a line profiler, but only for the backward pass, if that makes sense.

(1) If you profile both the forward and backward passes, the profiler trace should also include “forward backward links”: for each backward op, you’ll see a link back to the corresponding forward operation.

(2) If you also pass with_stack=True to the profiler, you should get a full Python stack trace / icicle view for every forward op.

Hopefully, the two of those together can tell you “which parts of my model forward correspond to the slow kernels in the backward”.
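Here is a minimal sketch of points (1) and (2) combined. It assumes the torch.profiler API of a recent PyTorch release and uses a hypothetical toy model in place of the real loss (the model, sizes, and output path are illustrative). It records forward and backward in one session with with_stack=True, then exports a trace that can be opened in a trace viewer:

```python
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile, record_function

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
activities = [ProfilerActivity.CPU]
if device.type == "cuda":
    activities.append(ProfilerActivity.CUDA)

# Hypothetical toy model standing in for the real loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), torch.nn.Tanh(), torch.nn.Linear(128, 1)
).to(device)
x = torch.randn(512, 128, device=device)

# Profile forward AND backward in the same session, with Python stacks,
# so the trace contains both halves of every forward/backward pair.
with profile(activities=activities, with_stack=True) as prof:
    with record_function("forward_pass"):
        loss = model(x).pow(2).mean()
    with record_function("backward_pass"):
        loss.backward()

# Load this file in https://ui.perfetto.dev (or chrome://tracing) to inspect
# the timeline and the per-op Python stacks.
trace_path = os.path.join(tempfile.gettempdir(), "fwd_bwd_trace.json")
prof.export_chrome_trace(trace_path)
print("trace written to", trace_path)
```

Recording stacks adds profiling overhead, so timings in this trace are best compared with each other rather than read as absolute numbers.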


Thanks, this seems to do almost what I want.

I’m just not sure how to interpret the results.

        from torch.profiler import ProfilerActivity, profile, record_function

        # torch.profiler.profile takes the activities list as a keyword argument
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True) as prof:
            optimizer.zero_grad()
            with record_function("fw_pass"):
                loss = loss_func(...)
            with record_function("backward_pass"):
                loss.backward()
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        optimizer.step()
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                fw_pass         3.73%     369.522ms        83.53%        8.275s        8.275s      88.440ms         0.89%        8.266s        8.266s             1  
                                               aten::to         0.33%      32.771ms         2.06%     204.397ms      15.089us     278.376ms         2.81%        1.785s     131.749us         13546  
                                          backward_pass         1.33%     131.403ms        16.20%        1.605s        1.605s      97.348ms         0.98%        1.606s        1.606s             1  
                                         aten::_to_copy         0.84%      83.673ms         1.57%     155.292ms      26.205us        1.370s        13.84%        1.512s     255.129us          5926  
                                           aten::select        12.99%        1.287s        13.41%        1.329s     131.321us     910.680ms         9.20%        1.040s     102.741us         10118  
                                     aten::linalg_solve         0.02%       1.883ms         1.34%     132.464ms     995.972us      42.809ms         0.43%     655.701ms       4.930ms           133  
                                              aten::add         7.52%     745.226ms         7.92%     784.386ms     389.854us     267.842ms         2.71%     618.816ms     307.563us          2012  
                                             aten::item         0.70%      69.572ms         1.23%     121.547ms       8.864us     430.553ms         4.35%     510.962ms      37.261us         13713  
                                     aten::_is_any_true         0.59%      58.656ms         4.13%     409.345ms      35.380us     163.199ms         1.65%     508.181ms      43.922us         11570  
                                        aten::unsqueeze         6.19%     613.161ms         6.37%     631.313ms     178.337us     265.527ms         2.68%     489.635ms     138.315us          3540  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
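Sorting by cuda_time_total puts enclosing ranges such as fw_pass and backward_pass near the top, because total time includes everything nested inside them. Sorting by self time and splitting rows by their top Python frames may be easier to read. Below is a sketch on a hypothetical toy model (the model, sizes, and row_limit are illustrative):

```python
import torch
from torch.profiler import ProfilerActivity, profile

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
activities = [ProfilerActivity.CPU]
if device.type == "cuda":
    activities.append(ProfilerActivity.CUDA)

# Hypothetical toy model standing in for loss_func(...).
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), torch.nn.Tanh(), torch.nn.Linear(128, 1)
).to(device)
x = torch.randn(512, 128, device=device)

with profile(activities=activities, with_stack=True) as prof:
    loss = model(x).pow(2).mean()
    loss.backward()

# Self time excludes nested children, so wrapper ranges stop dominating.
sort_key = "self_cuda_time_total" if device.type == "cuda" else "self_cpu_time_total"
# group_by_stack_n=5 splits each op's row by its top 5 stack frames
# (this needs with_stack=True above).
table = prof.key_averages(group_by_stack_n=5).table(sort_by=sort_key, row_limit=10)
print(table)
```

Ops such as aten::item copy a value back to the host, which waits for queued GPU work, so their CPU time can include time spent waiting on kernels launched earlier.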

Now, this does give some information, but I would like the stack view you mentioned. I tried to use

        prof.export_stacks("/tmp/profiler_stacks.txt", "self_cuda_time_total")

but this creates an empty file.
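One possible cause, offered as an assumption and not a confirmed diagnosis: the PyTorch profiler recipe pairs with_stack=True with experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) when examining stacks. Without it, some 2.x releases may keep no Python frames, which would leave export_stacks with nothing to write. _ExperimentalConfig is a private API and may change between releases. export_stacks also skips any event that lacks a stack or has a zero value for the chosen metric. Here is a sketch on a hypothetical toy model (model, sizes, and output path are illustrative):

```python
import os
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
activities = [ProfilerActivity.CPU]
if device.type == "cuda":
    activities.append(ProfilerActivity.CUDA)

# Hypothetical toy model standing in for the real loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), torch.nn.Tanh(), torch.nn.Linear(128, 1)
).to(device)
x = torch.randn(512, 128, device=device)

# Assumption: verbose=True makes the profiler keep Python frames for
# export_stacks in some PyTorch 2.x releases. Private API; may change.
config = torch._C._profiler._ExperimentalConfig(verbose=True)

with profile(activities=activities, with_stack=True, experimental_config=config) as prof:
    loss = model(x).pow(2).mean()
    loss.backward()

# Only events with a stack AND a non-zero value for this metric are written.
metric = "self_cuda_time_total" if device.type == "cuda" else "self_cpu_time_total"
stacks_path = os.path.join(tempfile.gettempdir(), "profiler_stacks.txt")
prof.export_stacks(stacks_path, metric)

# Each line is a semicolon-separated stack, a space, then an integer value:
# the "folded" format read by FlameGraph's flamegraph.pl.
with open(stacks_path) as f:
    lines = [line.rstrip("\n") for line in f if line.strip()]
print(f"{len(lines)} stack lines in {stacks_path}")
for line in lines[:3]:
    print(line)
```

If the file is non-empty, Brendan Gregg's FlameGraph script (flamegraph.pl) can render it as an SVG flame graph, which gives the kind of stack view asked about here.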