Profiler does not attribute backward-pass CUDA kernels to the enclosing record_function block

I’m using torch.profiler to profile a full training step (forward + backward + optimizer step). Each of these three stages is wrapped in its own record_function block. However, the backward block does not seem to account for the CUDA time of the kernels that autograd launches.

Here is a complete example:

import torch
import torch.nn as nn
from torch.profiler import profile, record_function

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10000, 100000)
        self.fc2 = nn.Linear(100000, 10000)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return x

def run_step(model, x, backward, optimizer=None):
    with record_function("forward"):
        y = model(x)
    if backward:
        with record_function("backward"):
            # The loss computation is also timed as part of this block.
            loss = torch.mean(y ** 2)
            loss.backward()
        if optimizer is not None:
            with record_function("optimizer"):
                optimizer.step()
                optimizer.zero_grad()
    return y

device = torch.device("cuda")
model = SimpleNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
x = torch.randn(10000, 10000, device=device)  # allocate the input directly on the GPU
with profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
             schedule=torch.profiler.schedule(wait=0, warmup=0, active=1, repeat=1),
             ) as prof:
    run_step(model, x, backward=True, optimizer=optimizer)
    prof.step()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000))

Running this prints the following table (abridged):

---------------------------------------------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
                                               Name     Self CPU %       Self CPU    CPU total %      CPU total   CPU time avg      Self CUDA    Self CUDA %     CUDA total  CUDA time avg     # of Calls
---------------------------------------------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------  -------------
                                      ProfilerStep*          0.00%      126.000us         71.34%         2.066s         2.066s        0.000us          0.00%         1.193s         1.193s              1
autograd::engine::evaluate_function: AddmmBackward0          0.00%       89.000us         27.18%      787.008ms      393.504ms        0.000us          0.00%         1.159s      579.507ms              2
                                     AddmmBackward0          0.00%       67.000us         14.17%      410.129ms      205.065ms        0.000us          0.00%         1.158s      578.788ms              2
                                           aten::mm          0.01%      392.000us         14.16%      410.011ms      136.670ms         1.158s         58.29%         1.158s      385.859ms              3
                                            forward          0.01%      314.000us         14.42%      417.382ms      417.382ms        0.000us          0.00%         1.130s         1.130s              1
...
                                          optimizer          0.01%      243.000us         15.15%      438.596ms      438.596ms        0.000us          0.00%       62.893ms       62.893ms

AddmmBackward0 presumably runs during loss.backward(), yet its CUDA time (1.158s) is not attributed to the enclosing record_function block. The row for the backward block appears much further down the table, with a negligible CUDA total:

                                           backward         27.97%      809.707ms         41.77%         1.209s         1.209s        0.000us          0.00%      409.000us      409.000us

So the backward block accounts for significant CPU time (1.209s total), but only 409us shows up in its CUDA total column.
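
To make the gap concrete, here is a small sketch that converts the "CUDA total" strings from the rows above into milliseconds and compares them. The helper to_ms and the dictionary are my own names for illustration and are not part of torch.profiler; the values are copied by hand from the output shown in this post.

```python
# Hypothetical helper (not part of torch.profiler): parse the time strings
# printed in the profiler table, e.g. "1.158s", "62.893ms", "409.000us".
UNIT_TO_MS = {"us": 1e-3, "ms": 1.0, "s": 1e3}


def to_ms(text):
    # Try the two-letter suffixes first, since every unit ends in "s".
    for unit in ("us", "ms", "s"):
        if text.endswith(unit):
            return float(text[: -len(unit)]) * UNIT_TO_MS[unit]
    raise ValueError(f"unrecognized time string: {text!r}")


# "CUDA total" values copied from the rows shown in this post.
cuda_total = {
    "AddmmBackward0": "1.158s",
    "backward": "409.000us",
}

cuda_total_ms = {name: to_ms(value) for name, value in cuda_total.items()}
share = cuda_total_ms["backward"] / cuda_total_ms["AddmmBackward0"]

print(f"backward block: {cuda_total_ms['backward']:.3f} ms")
print(f"AddmmBackward0: {cuda_total_ms['AddmmBackward0']:.3f} ms")
print(f"backward / AddmmBackward0: {share:.4%}")
```

The backward block's CUDA total comes out well under 0.1% of AddmmBackward0's, even though I would expect AddmmBackward0 to run inside that block.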

Am I doing something silly in the script above, or does the profiler currently have trouble attributing the *Backward kernels to a record_function block?