How to get stable `torch.cuda.Event` timings for reliable benchmarking?

1. The Goal & Problem

I am trying to reliably benchmark a DSA (DeepSeek Sparse Attention) kernel to understand its performance.

However, torch.cuda.Event timings show significant variance (~10-30%) even after multiple warmups, averaging, and other standard practices. This noise makes it difficult to compare optimizations or determine performance bottlenecks accurately.

2. Minimal Example (GEMM)

The actual DSA code is integrated into vLLM, so here is a simpler, self-contained GEMM benchmark that uses the same measurement approach.

import torch
import statistics

def benchmark_kernel(num_iterations=10):
    device = torch.device("cuda")
    x = torch.randn(10000, 6000, device=device)
    y = torch.randn(6000, 2000, device=device)

    # Warmup
    for _ in range(5):
        _ = torch.matmul(x, y)
    torch.cuda.synchronize()

    # Measurement
    timings = []
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    for _ in range(num_iterations):
        start_event.record()
        # Simulate a kernel with multiple operations
        z = torch.matmul(x, y)
        z = torch.matmul(x, y)
        z = torch.matmul(x, y)
        end_event.record()
        torch.cuda.synchronize()
        timings.append(start_event.elapsed_time(end_event))

    print(f"Mean: {statistics.mean(timings):.3f} ms")
    print(f"Min: {min(timings):.3f} ms, Max: {max(timings):.3f} ms")
    return timings

if __name__ == "__main__":
    benchmark_kernel()

GEMM Results:

Mean: 15.373 ms
Min: 15.352 ms, Max: 15.423 ms

(Note: While this specific GEMM example is relatively stable, the variance is much higher in my actual, more complex sparse attention kernel as shown below.)

3. Real-World Variance (Sparse Attention Kernel)

When measuring my actual target—the DSA forward module in vLLM—the variance is much more pronounced.

Based on 10 measurements (context length 512):

  • Min: 1.601 ms
  • Max: 2.295 ms
  • Mean: 1.831 ms

This ~30% spread between min and max makes it impossible to reliably calculate the indexer/DSA time ratio, which is my primary goal.

4. Environment

  • GPU: NVIDIA H100 80GB HBM3
  • Software: PyTorch 2.8.0+cu128, CUDA 12.8, Driver 550.54.15, Ubuntu 22.04.5 LTS

5. What I’ve Tried

  1. Warmup Iterations (5 iterations)
  2. Averaging Multiple Runs (10 iterations)
  3. Outlier Removal (trimming the min/max values; see the small aggregation sketch below)
  4. torch.cuda.synchronize() after each operation
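
For reference, here is a small sketch of how I aggregate the measurements for points 2 and 3 (the summarize helper is only illustrative, not part of the benchmark above):

import statistics

def summarize(timings_ms):
    # Drop the single smallest and largest sample ("outlier removal" above),
    # then report a trimmed mean alongside the plain mean and median.
    trimmed = sorted(timings_ms)[1:-1]
    return {
        "mean": statistics.mean(timings_ms),
        "trimmed_mean": statistics.mean(trimmed),
        "median": statistics.median(timings_ms),
        "min": min(timings_ms),
        "max": max(timings_ms),
    }

# e.g. summarize(benchmark_kernel(num_iterations=10))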

6. Questions

  1. Is this level of timing variance expected for complex kernels?
  2. Are there better, more stable methods or “best practices” for benchmarking with PyTorch on CUDA beyond what I’ve tried?
  3. Could this be related to GPU power states, scheduler jitter, or something else I can control?

Any advice on how to achieve more stable and reproducible timings would be greatly appreciated. If this is a known issue, pointers to relevant documentation would also be helpful.

Here are several things to consider when timing GPU kernels:

  1. Host overheads: in your case this should not be the dominant issue, since the kernels run for >1 ms, but note that event-based timing still includes overhead from the CPU recording the events, selecting and launching the kernel, and the GPU reporting back. If you use this script with smaller dimensions, I would recommend capturing a CUDA graph in which the kernel is called many times and timing a replay of the entire graph with events (see the first sketch after this list). You can also use CUPTI-based profilers, or the CUPTI APIs yourself, to get significantly more accurate timings for individual kernel launches, but that is generally more work than building a CUDA graph.
  2. There is some inherent variance from the kernel runs themselves, nothing to do with how you are timing the runs:
    a. Power: in this specific case you should expect the H100 to be power-limited, so anything that affects the GPU's power draw at a given moment will affect performance here.
    b. Thermals: one thing that often affects power draw is the thermal state of the GPU. It takes some time to heat up, but if you run this benchmark many times (or simply call the kernel enough times in a row), temperatures may rise enough that limiters kick in to avoid overheating the processor. You will notice that if you idle for a while between bursts of benchmarking, timings can vary quite a bit. Both power and thermal variance can be controlled by, e.g., locking the clocks down to their base level, which is what kernel profiling tools like Nsight Compute do by default (see the clock-locking sketch below). Most operations scale roughly with clock rate, so once you know the stable base-clock performance, you only need an approximate idea of how high the clocks reach in real runs to estimate absolute performance; for many purposes, such as comparing optimizations, relative performance is enough anyway.
    c. Cache effects: even if all inputs are cached the same way across runs, the exact access pattern is typically non-deterministic on a multi-process system, which leads to some variance. This one, at least, is something you can control: tools like Nsight Compute keep caches cold between runs, but that may not be realistic for your workload. Often only a specific input is expected to be (partially) cached; in DL workloads it is usually the activations, since they were just written by the previous kernel. In that case you can rotate the tensor you expect to act as the weights while keeping the other operand fixed (and therefore warm in cache); see the last sketch at the end of this reply.
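
To illustrate point 1, here is a minimal sketch (reusing the GEMM from the question rather than the DSA kernel) of capturing many iterations into a CUDA graph and timing a single replay of the whole graph, which amortizes launch and event overheads across all captured kernels:

import torch

device = torch.device("cuda")
x = torch.randn(10000, 6000, device=device)
y = torch.randn(6000, 2000, device=device)
z = torch.empty(10000, 2000, device=device)  # static output buffer for capture

# Warm up on a side stream before capture (the pattern recommended for graph capture).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(5):
        torch.matmul(x, y, out=z)
torch.cuda.current_stream().wait_stream(s)

# Capture many iterations of the kernel into one graph.
n_inner = 20
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    for _ in range(n_inner):
        torch.matmul(x, y, out=z)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
g.replay()  # a single call replays all captured kernels
end.record()
torch.cuda.synchronize()
print(f"Per-iteration time: {start.elapsed_time(end) / n_inner:.4f} ms")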

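For point 2b, here is a sketch of locking the SM clocks around a benchmark run via nvidia-smi (wrapped in Python only for convenience; the 1500 MHz value is purely illustrative and device-specific, so query the supported clocks first, and note that locking clocks typically requires root):

import subprocess

def nvsmi(*args):
    # Thin helper: run an nvidia-smi command and print its output.
    result = subprocess.run(["nvidia-smi", *args], capture_output=True, text=True)
    print(result.stdout or result.stderr)

# List the SM clocks this GPU supports, then pick a value near the base clock.
nvsmi("-q", "-d", "SUPPORTED_CLOCKS")

# Lock the SM clocks to a fixed frequency (MHz) for the duration of the benchmark.
nvsmi("-lgc", "1500,1500")  # illustrative value; typically needs root

# ... run the benchmark here ...

# Restore the default clock behaviour afterwards.
nvsmi("-rgc")
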
It is hard to control even more variables from PyTorch/nvidia-smi, but hopefully the above, together with the last sketch below, helps you a little.
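
Finally, a rough sketch of the cache-rotation idea from 2c: keep one operand fixed (the stand-in for activations) and cycle through several copies of the other operand so the "weights" are not repeatedly served from a warm L2. The copy count and sizes here are arbitrary, chosen only so the rotated copies exceed L2 capacity:

import torch
import statistics

device = torch.device("cuda")
x = torch.randn(10000, 6000, device=device)  # fixed operand (stand-in for activations)
# Several copies of the "weight" operand; rotating through them means consecutive
# timed iterations never reuse the same, still-cached tensor.
weights = [torch.randn(6000, 2000, device=device) for _ in range(8)]

# Warmup
for w in weights:
    torch.matmul(x, w)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

timings = []
for i in range(40):
    w = weights[i % len(weights)]  # pick the next copy in the rotation
    start.record()
    torch.matmul(x, w)
    end.record()
    torch.cuda.synchronize()
    timings.append(start.elapsed_time(end))

print(f"Mean: {statistics.mean(timings):.3f} ms "
      f"(min {min(timings):.3f}, max {max(timings):.3f})")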