Torch operation time measurement

I want to measure the execution time of a single torch operation.

import torch
from torch.profiler import profile, record_function, ProfilerActivity

device = "cuda:0"
add_1 = torch.randn(1, 64, 1, 1).to(device)

# with_stack belongs inside the profile() call, not on a separate line
with profile(activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
             record_shapes=True, profile_memory=True, with_stack=True) as prof:
    with record_function("operation"):
        rsqrt = torch.ops.aten.rsqrt.default(add_1)
print(prof.key_averages().table())

The profiler reports the following totals:
Self CPU time total: 22.537ms
Self CUDA time total: 1.000us
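Profiling a single first call also captures one-time costs such as lazy CUDA initialization and profiler startup, which may account for part of the large CPU total above. A minimal sketch, assuming a recent PyTorch with `torch.profiler.schedule`, that runs a few discarded warmup steps before recording; it falls back to CPU when no GPU is present, and the step counts are arbitrary choices:

```python
import torch
from torch.profiler import profile, schedule, ProfilerActivity

# Fall back to CPU so the sketch also runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
add_1 = torch.randn(1, 64, 1, 1, device=device)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

# Skip 1 step, warm up for 2 steps (traced but discarded), then record 5 steps.
sched = schedule(wait=1, warmup=2, active=5, repeat=1)

with profile(activities=activities, schedule=sched) as prof:
    for _ in range(8):  # wait + warmup + active = 8 steps
        rsqrt = torch.ops.aten.rsqrt.default(add_1)
        prof.step()  # mark the end of one profiler step

# Only the active steps are aggregated here.
print(prof.key_averages().table(sort_by="cpu_time_total"))
```

With this schedule the reported totals cover five warmed-up calls instead of one cold call.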

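import time

torch.cuda.current_stream().synchronize()  # make sure no earlier GPU work is pending
t0 = time.time()
rsqrt = torch.ops.aten.rsqrt.default(add_1)
torch.cuda.current_stream().synchronize()  # wait for the kernel to finish
t1 = time.time()
print(f"{(t1 - t0) * 1000:.2f} ms")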

This snippet reports 0.17 ms.

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
rsqrt = torch.ops.aten.rsqrt.default(add_1)
end.record()
torch.cuda.current_stream().synchronize()
print(f"{start.elapsed_time(end):.2f} ms")  # elapsed_time() returns milliseconds

With torch.cuda.Event, the measured GPU execution time is 0.10 ms.

  1. What is the right approach to measuring the execution time of a torch operation?

  2. I have a model and want to measure its training time:


import torch
import torchvision
import time

device = "cuda:0"
num_iter = 1

model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT).to(device)

inputs = torch.randn(1, 3, 224, 224).to(device)
labels = torch.randint(0, 1000, (1,)).to(device)  # class indices for CrossEntropyLoss

learning_rate = 0.001
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

torch.cuda.current_stream().synchronize()
t0 = time.time()

optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

torch.cuda.current_stream().synchronize()
t1 = time.time()
print(f"Total Time taken for the model training: {(t1 - t0) / num_iter * 1000} ms")

Is this the right way to measure the training time of a model?
Any kind of help will be appreciated. Thanks.

I don't know how the built-in profiler measures the time, but note that you need warmup iterations and should take the mean or median runtime over multiple iterations. With that, I get approximately the same result from manual timing, CUDA events, and torch.utils.benchmark.Timer:

nb_iters = 100

# warmup
for _ in range(10):
    rsqrt = torch.ops.aten.rsqrt.default(add_1)

torch.cuda.current_stream().synchronize()
t0 = time.perf_counter()
for _ in range(nb_iters):
    rsqrt = torch.ops.aten.rsqrt.default(add_1)
torch.cuda.current_stream().synchronize()
t1 = time.perf_counter()
print((t1 - t0)/nb_iters * 1e6)  # mean time per call in microseconds
# 7.923380035208538
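The manual timing above averages over one block of calls. If the spread between calls is also of interest, a stdlib-only sketch like the following records each call separately and reports mean and median; `time_per_iteration` and `summarize_us` are hypothetical helper names, and the optional `sync` callback is where one would pass `torch.cuda.synchronize` for GPU work. Syncing after every call adds the synchronization latency to each sample, so for tiny kernels like `rsqrt` the block average above is probably the more representative number.

```python
import statistics
import time
from typing import Callable, Dict, List, Optional


def time_per_iteration(
    fn: Callable[[], object],
    sync: Optional[Callable[[], None]] = None,
    warmup: int = 10,
    iters: int = 100,
) -> List[float]:
    """Return the wall-clock duration of each call to ``fn`` in seconds.

    ``sync`` is called once after warmup and after every timed call, so that
    asynchronous work (e.g. queued CUDA kernels) is included in each sample.
    """
    sync = sync or (lambda: None)
    for _ in range(warmup):  # warmup calls are executed but not recorded
        fn()
    sync()
    durations = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        sync()  # wait for the work to finish before stopping the clock
        durations.append(time.perf_counter() - t0)
    return durations


def summarize_us(durations: List[float]) -> Dict[str, float]:
    """Convert durations from seconds to microseconds; report mean and median."""
    us = [d * 1e6 for d in durations]
    return {"mean_us": statistics.mean(us), "median_us": statistics.median(us)}


# Stand-in workload; with PyTorch one might pass
# fn=lambda: torch.ops.aten.rsqrt.default(add_1), sync=torch.cuda.synchronize
samples = time_per_iteration(lambda: sum(range(1000)), warmup=5, iters=50)
print(summarize_us(samples))
```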

# warmup
for _ in range(10):
    rsqrt = torch.ops.aten.rsqrt.default(add_1)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
for _ in range(nb_iters):
    rsqrt = torch.ops.aten.rsqrt.default(add_1)
end.record()
torch.cuda.current_stream().synchronize()
print((start.elapsed_time(end))/nb_iters * 1e3)  # elapsed_time() is in ms; convert to us per call
# 7.798720002174377
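CUDA events can also give per-call timings without a host synchronization after every call: record a start/end event pair around each call and read all timings after a single synchronization. A minimal sketch, assuming a CUDA device; `cuda_event_times_ms` is a hypothetical helper, and the usage part simply skips when no GPU is available:

```python
import statistics

import torch


def cuda_event_times_ms(fn, warmup=10, iters=100):
    """Return the GPU-side duration of each call to ``fn`` in milliseconds."""
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for _ in range(warmup):  # warmup calls are not timed
        fn()
    for start, end in zip(starts, ends):
        start.record()  # enqueued on the current stream before the op
        fn()
        end.record()  # enqueued on the current stream right after the op
    torch.cuda.synchronize()  # a single sync so that all events have completed
    return [start.elapsed_time(end) for start, end in zip(starts, ends)]


if torch.cuda.is_available():
    x = torch.randn(1, 64, 1, 1, device="cuda:0")
    times_ms = cuda_event_times_ms(lambda: torch.ops.aten.rsqrt.default(x))
    print(f"median: {statistics.median(times_ms) * 1e3:.2f} us")
else:
    times_ms = []
    print("CUDA not available; skipping the event-based measurement.")
```

If the GPU is idle while the host is still dispatching the op, the gap between `start.record()` and the kernel launch is included in each sample, so for very small, launch-bound ops these numbers still reflect some host overhead.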

import torch.utils.benchmark
timer = torch.utils.benchmark.Timer(stmt="torch.ops.aten.rsqrt.default(add_1)", globals=globals())
timer.blocked_autorange()
# <torch.utils.benchmark.utils.common.Measurement object at 0x7f27f78b7be0>
# torch.ops.aten.rsqrt.default(add_1)
#   Median: 7.10 us
#   3 measurements, 10000 runs per measurement, 1 thread
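For the training-time question, the same `torch.utils.benchmark.Timer` approach can wrap a whole training iteration; the Timer synchronizes CUDA when needed, so asynchronous kernels are included in the measurement. A minimal sketch, assuming a small hypothetical `nn.Sequential` model as a stand-in for the ResNet-50 so it also runs quickly on CPU; `train_step` is a hypothetical helper name, and with the real model one would substitute it and its inputs:

```python
import torch
import torch.utils.benchmark as benchmark

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Hypothetical small stand-in for the ResNet-50 from the question.
model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(3 * 32 * 32, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 1000),
).to(device)
inputs = torch.randn(8, 3, 32, 32, device=device)
labels = torch.randint(0, 1000, (8,), device=device)  # class indices
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train_step():
    # One full training iteration: forward, loss, backward, parameter update.
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()


# Explicit warmup (e.g. allocator growth, Adam state creation) before measuring.
for _ in range(5):
    train_step()

timer = benchmark.Timer(stmt="train_step()", globals={"train_step": train_step})
measurement = timer.blocked_autorange(min_run_time=0.5)
print(measurement)
print(f"median per training step: {measurement.median * 1e3:.3f} ms")
```

`Measurement.median` is in seconds; it is taken over the measured blocks, which makes it less sensitive to occasional slow iterations than a single timed step.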