Slowdown in CUDA graph execution using cudaMemsetAsync (vs. a fill kernel)

We have a PyTorch-based inference system that heavily uses graph capture+replay. Looking at profiles, every time there is a cudaMemset, there are gaps in the execution that look like this:

In this case, the (4byte) cudaMemsetAsync takes 1.7us, but the gap between kernels is almost 13us. PyTorch itself makes very little use of memset (it seems to prefer fill kernels), but plenty of libraries use memset as part of various kernel prep (here, the TransformerEngine fp8 layernorm kernel).

I was able to reproduce this effect with a small microbenchmark. It looks roughly like this:

# Sketch only, not runnable as-is: `g` is a torch.cuda.CUDAGraph created
# earlier, and warmup on a side stream precedes capture (see full test.py below).
z = torch.zeros(1, device="cuda")
x = torch.randn(512, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")

# [warmup, setup, etc], then:
# capture
with torch.cuda.graph(g):
  for _ in range(1000):
    torch.ops.memset_cuda.memset_cuda(z)  # Case 1: use memset
    # z.zero_()  # Case 2: use vectorized_elementwise_kernel
    torch.matmul(x, w.T)

# Measure graph time
torch.cuda.synchronize()
tic = time.time()
g.replay()
torch.cuda.synchronize()
toc = time.time()
print(f"{(toc - tic) * 1000.}ms")

On a single H100-SXM5, I get these measurements for the graph time:

using memset: 34.8ms
using zero_: 33.4ms

This is a ~4% E2E slowdown due to memset vs. a fill kernel! Looking at the profile, I see similar gaps as above, though their relative size is smaller. Any ideas why this might be? Or any workarounds? In our own code we can prefer the fill kernels, but it is impossible to work around (e.g.) cuBLAS’ use of memset. Tagging @ptrblck for NV awareness.

Aside: the call to memset_cuda is a 5-line cpp extension I added that simply calls:

// Excerpt from the extension: async memset of the tensor's bytes to zero,
// enqueued on PyTorch's current CUDA stream.
cudaMemsetAsync(x.data_ptr<float>(), 0, x.numel() * sizeof(float),
    at::cuda::getCurrentCUDAStream().stream());
1 Like

This is an interesting observation, as this topic claims the opposite, at least for eager execution.

Could you post your full extension code to reproduce the issue?

Thanks for taking a look @ptrblck!

Re: eager mode differences, it’s true that the reported operation time for the set-to-zero is lower with memset. There’s a little variation, but staring at some nsys nvprof --print-gpu-trace output, it looks like for this 4-byte memset I see:

  • cudaMemsetAsync: 0.8us
  • vectorized_elementwise_kernel: 1.1-1.3us

But this few-hundred-nanosecond difference is swamped in graph execution by the gaps with memset. (Do CE operations not have fast-dependent-launch?)

I don’t see a way to attach files, so here are the three files to reproduce my original numbers:

1: setup.py

from setuptools import setup
from torch.utils import cpp_extension

# Build configuration for the `memset_cuda` CUDA extension imported by test.py.
memset_extension = cpp_extension.CUDAExtension(
    name="memset_cuda",
    sources=["memset_cuda.cu"],
)

setup(
    name="memset_cuda",
    ext_modules=[memset_extension],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)

2: memset_cuda.cu

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <torch/extension.h>

namespace memset_cuda {

// Zero-fills tensor `x` with cudaMemsetAsync on the current PyTorch CUDA
// stream (used to compare graph-captured memset nodes against fill kernels).
// Dispatches on float/double/half/bfloat16; other dtypes make AT_DISPATCH
// throw. NOTE(review): the numel*sizeof byte count assumes `x` is contiguous
// — confirm at call sites.
void memset_cuda(torch::Tensor& x) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "memset_cuda", ([&] {
            // Check the runtime call's status: an error here was previously
            // dropped and would only surface at some later, unrelated CUDA call.
            C10_CUDA_CHECK(cudaMemsetAsync(x.data_ptr<scalar_t>(), 0,
                                           x.numel() * sizeof(scalar_t),
                                           at::cuda::getCurrentCUDAStream().stream()));
        }));
}

// Intentionally empty: the op is exposed via the TORCH_LIBRARY registration
// below as torch.ops.memset_cuda.memset_cuda, not through pybind.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

TORCH_LIBRARY(memset_cuda, m) { m.def("memset_cuda(Tensor x) -> ()"); }

TORCH_LIBRARY_IMPL(memset_cuda, CUDA, m) { m.impl("memset_cuda", &memset_cuda); }

}  // namespace memset_cuda

3: test.py

import time
import torch
import memset_cuda


def main():
    """Capture a 1000-iteration (zero + matmul) CUDA graph and time one replay.

    Toggle ``use_memset`` to pick between the cudaMemsetAsync extension op
    (True) and PyTorch's fill kernel via ``zero_()`` (False).
    """
    z = torch.zeros(1, device="cuda")
    x = torch.randn(512, 1024, device="cuda")
    w = torch.randn(1024, 1024, device="cuda")
    use_memset = True

    def step():
        # One benchmark iteration: zero the 4-byte flag, then a small matmul.
        if use_memset:
            torch.ops.memset_cuda.memset_cuda(z)
        else:
            z.zero_()
        torch.matmul(x, w.T)

    graph = torch.cuda.CUDAGraph()
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())

    # Warmup on a side stream before capture.
    with torch.cuda.stream(side_stream):
        for _ in range(30):
            step()
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture the same 1000-iteration sequence into the graph.
    with torch.cuda.graph(graph):
        for _ in range(1000):
            step()

    # Measure one whole-graph replay with host wall-clock timing.
    torch.cuda.synchronize()
    tic = time.time()
    graph.replay()
    torch.cuda.synchronize()
    toc = time.time()
    print(f"{(toc - tic) * 1000.}ms")


# Script entry point: build inputs, capture the graph, and print the replay time.
if __name__ == "__main__":
    main()

Run / measure:

# Build the extension
python setup.py build
# Run with use_memset=True
# Path will be different if you have a different python version
PYTHONPATH=build/lib.linux-x86_64-cpython-311 python test.py
# Then modify to have use_memset=False and re-run

Fwiw, we are on CUDA 12.1.

Thank you Carl! Let me take a look at it and profile the usage.

1 Like

Hi @ptrblck – gently pinging this one to see if you’ve learned anything – thanks!

Thanks for the ping!
I’ve talked to @eqy and we were speculating if the usage of custom extensions by itself is adding enough CPU overhead (we’ve seen it in the past), but would need to profile it properly and maybe even compare it to a “native” implementation.

Thanks for taking a look. Fwiw, I don’t think this is PyTorch-specific (so maybe this is the wrong forum?). Here is a sample cuda file that does nothing other than use either a kernel or cudaMemsetAsync to zero-out a 4-byte flag:

#include <cuda.h>
#include <stdio.h>

#include <cassert>
#include <chrono>
#include <iostream>

#define CHECK_CUDA(val) check((val), #val, __FILE__, __LINE__)
// Abort with a diagnostic (file:line, error string, and the failing call's
// text) if a CUDA runtime call did not return cudaSuccess.
void check(cudaError_t err, const char* const func, const char* const file, const int line) {
    if (err == cudaSuccess) {
        return;
    }
    std::cerr << "CUDA Runtime Error at: " << file << ":" << line << std::endl;
    std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
    std::exit(EXIT_FAILURE);
}

__global__ void zero_out(float* a) { *a = 0; }

// Captures a 1000-node graph that repeatedly zeroes a 4-byte device flag,
// either with a tiny fill kernel or with cudaMemsetAsync, then times a single
// warm replay of the whole graph. Usage: ./test [kernel|memset]
int main(int argc, char** argv) {
    if (argc != 2) {
        std::cerr << "Usage: " << argv[0] << " [kernel|memset]" << std::endl;
        std::exit(EXIT_FAILURE);
    }
    bool use_kernel = false;
    if (std::string(argv[1]) == "kernel") {
        use_kernel = true;
    } else if (std::string(argv[1]) != "memset") {
        std::cerr << "Usage: " << argv[0] << " [kernel|memset]" << std::endl;
        std::exit(EXIT_FAILURE);
    }
    CHECK_CUDA(cudaSetDevice(0));
    float* d_z;
    CHECK_CUDA(cudaMalloc(&d_z, sizeof *d_z));
    CHECK_CUDA(cudaMemset(d_z, 0, sizeof *d_z));

    cudaStream_t stream, stream2;
    cudaGraph_t graph;
    CHECK_CUDA(cudaStreamCreate(&stream));
    CHECK_CUDA(cudaStreamCreate(&stream2));
    CHECK_CUDA(cudaDeviceSynchronize());

    // 1) Capture: record 1000 zeroing ops into the graph (no GPU work runs yet).
    CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
    for (int i = 0; i < 1000; ++i) {
        if (use_kernel) {
            zero_out<<<1, 1, 0, stream>>>(d_z);
        } else {
            CHECK_CUDA(cudaMemsetAsync(d_z, 0, sizeof *d_z, stream));
        }
    }
    CHECK_CUDA(cudaStreamEndCapture(stream, &graph));
    // 2) Instantiate and warmup (replayed on a different stream than capture).
    cudaGraphExec_t graphExec;
    CHECK_CUDA(cudaGraphInstantiate(&graphExec, graph));  // CUDA 12 two-arg form
    for (int i = 0; i < 5; ++i) {
        CHECK_CUDA(cudaGraphLaunch(graphExec, stream2));
    }
    // 3) Measure: host wall-clock around one full graph replay.
    CHECK_CUDA(cudaDeviceSynchronize());
    auto tic = std::chrono::high_resolution_clock::now();
    CHECK_CUDA(cudaGraphLaunch(graphExec, stream2));
    CHECK_CUDA(cudaDeviceSynchronize());
    auto toc = std::chrono::high_resolution_clock::now();

    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(toc - tic);
    std::cout << "Execution time: " << duration.count() << "us" << std::endl;

    // Release everything we created; the original leaked these until exit.
    CHECK_CUDA(cudaGraphExecDestroy(graphExec));
    CHECK_CUDA(cudaGraphDestroy(graph));
    CHECK_CUDA(cudaStreamDestroy(stream));
    CHECK_CUDA(cudaStreamDestroy(stream2));
    CHECK_CUDA(cudaFree(d_z));
    return 0;
}

Supposing this file is test.cu, here are the results running on my own machine (with an H100-SXM5):

> nvcc -O3 -arch=sm_90 -o test test.cu
> ./test memset
Execution time: 1490us
> ./test kernel
Execution time: 1045us

In other words, the average time per zeroing operation is ~1us with the kernel and ~1.5us with memset. If you look at traces, though, you’ll see that an individual memset takes closer to 800ns, so half of the time is gaps in execution. This matches what I saw all the way at the beginning of the thread, though the gaps we see in real applications are much larger (can be as large as 10us) – certainly if they were never more than 1us, that would be great!

1 Like