I’m using torch.profiler to profile a full training step (forward + backward + optimizer step), with each of the three stages wrapped in its own record_function block. However, the backward block does not seem to account for the CUDA time spent in the kernels that autograd launches.
Here is a complete example:
import torch
import torch.nn as nn
from torch.profiler import profile, record_function

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10000, 100000)
        self.fc2 = nn.Linear(100000, 10000)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return x

def run_step(model, x, backward, optimizer=None):
    with record_function("forward"):
        y = model(x)
    if backward:
        with record_function("backward"):
            loss = torch.mean(y ** 2)
            loss.backward()
    if optimizer is not None:
        with record_function("optimizer"):
            optimizer.step()
            optimizer.zero_grad()
    return y

device = torch.device("cuda")
model = SimpleNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
x = torch.randn(10000, 10000).to(device)

with profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=0, warmup=0, active=1, repeat=1),
) as prof:
    run_step(model, x, backward=True, optimizer=optimizer)
    prof.step()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10000))
When running this, we get:
-----------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  -------------  ------------
Name                                                      Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-----------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  -------------  ------------
ProfilerStep*                                                  0.00%     126.000us        71.34%        2.066s        2.066s       0.000us         0.00%        1.193s         1.193s             1
autograd::engine::evaluate_function: AddmmBackward0            0.00%      89.000us        27.18%     787.008ms     393.504ms       0.000us         0.00%        1.159s      579.507ms             2
AddmmBackward0                                                 0.00%      67.000us        14.17%     410.129ms     205.065ms       0.000us         0.00%        1.158s      578.788ms             2
aten::mm                                                       0.01%     392.000us        14.16%     410.011ms     136.670ms        1.158s        58.29%        1.158s      385.859ms             3
forward                                                        0.01%     314.000us        14.42%     417.382ms     417.382ms       0.000us         0.00%        1.130s         1.130s             1
...
optimizer                                                      0.01%     243.000us        15.15%     438.596ms     438.596ms       0.000us         0.00%      62.893ms       62.893ms             1
So AddmmBackward0 is presumably invoked during loss.backward(), yet its time is not attributed to the corresponding record_function block. The row for the backward block appears much further down the table, with a trivial CUDA total:
backward                                                      27.97%     809.707ms        41.77%        1.209s        1.209s       0.000us         0.00%     409.000us      409.000us             1
So it does account for significant CPU time, but almost nothing in the CUDA total column.
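For reference, here is a rough cross-check I can run after the profiler context exits (a sketch, not something taken from the output above; it assumes a trace was collected for the step, that the autograd engine rows don't nest inside one another, and that cuda_time_total on the averaged events is in microseconds):

# Sum the CUDA time attributed to the top-level autograd engine rows,
# instead of relying on the "backward" record_function block.
backward_cuda_us = sum(
    evt.cuda_time_total
    for evt in prof.key_averages()
    if evt.key.startswith("autograd::engine::evaluate_function")
)
print(f"backward CUDA total: {backward_cuda_us / 1e6:.3f} s")

# The raw events can also be dumped to a Chrome trace (viewable in
# chrome://tracing or Perfetto), which shows which thread each op and
# kernel launch is attributed to:
prof.export_chrome_trace("trace.json")

If that sum lines up with the evaluate_function rows in the table, it would suggest the backward CUDA time is being attributed to the autograd nodes rather than to my backward block.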
Am I doing something silly in the script above, or is there a known difficulty with aggregating the *Backward kernels under record_function blocks in the profiler?