I am solving an optimization problem with PyTorch, and the forward pass is roughly 20-40 times faster than the backward pass. I would like to know the best way to profile just the call to loss.backward().
I can do something like this:
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    loss.backward()
But it only highlights the usage of low-level functions, and it's unclear where the actual bottleneck is. Ideally, I would like some kind of line profiler, but only for the backward pass, if that makes sense.
(1) If you profile both the forward and the backward pass, the profiler trace should also include "forward/backward links": for each backward op, you'll see a link back to the corresponding forward op.
(2) If you also pass with_stack=True to the profiler, you should get a full Python stack trace / icicle view for every forward op.
Hopefully the two of those together can tell you which parts of your model's forward correspond to the slow kernels in the backward.
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True) as prof:
    optimizer.zero_grad()
    with record_function("fw_pass"):
        loss = loss_func(...)
    with record_function("bw_pass"):
        loss.backward()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
optimizer.step()
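Since you asked for something like a line profiler: when with_stack=True is set, key_averages(group_by_stack_n=N) groups ops by their top N Python stack frames, so each row in the table points back at source lines. A runnable CPU-only sketch (the model is a toy stand-in; on a GPU you'd add ProfilerActivity.CUDA and sort by self_cuda_time_total instead):

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity

model = torch.nn.Linear(128, 128)  # toy stand-in for your model
inp = torch.randn(64, 128)

with profile(activities=[ProfilerActivity.CPU], with_stack=True) as prof:
    with record_function("bw_pass"):
        model(inp).sum().backward()

# group_by_stack_n attaches the top-5 stack frames to each averaged op,
# which is about as close to a "line profiler for the backward" as you get
print(prof.key_averages(group_by_stack_n=5)
          .table(sort_by="self_cpu_time_total", row_limit=10))
```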