Copy ops on the default PyTorch CUDA stream surprisingly block copy ops on other, non-blocking PyTorch CUDA streams

Hi experts,

We are working on LLM inference and noticed something surprising in our model inference stack. From our observation, copy operations (e.g. torch.Tensor.to()/cpu()/tolist(), with either non_blocking=True or non_blocking=False) issued on the default stream block copy operations running on other CUDA streams that were created with cudaStreamNonBlocking.

Per the CUDA documentation, we did not expect this: the legacy default stream should only implicitly synchronize with streams that were created without the cudaStreamNonBlocking flag.
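For context, here is a minimal sketch of our reading of that rule (assuming, as appears to be the case in our build, that PyTorch's default stream is the legacy default stream with a null handle, and that torch.cuda.Stream() hands out pool streams created with cudaStreamNonBlocking):

import torch

# Documented rule (CUDA runtime): work on the legacy default stream implicitly
# synchronizes with streams created WITHOUT cudaStreamNonBlocking; streams
# created WITH the flag should be exempt from that implicit synchronization.
default = torch.cuda.default_stream()  # legacy default stream
side = torch.cuda.Stream()             # pool stream; reported non-blocking below

print(default.cuda_stream)  # 0 -> null handle, i.e. the legacy default stream
print(side.cuda_stream)     # non-zero raw cudaStream_t handle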

Here is a quick reproduction script:

import ctypes
import threading
from ctypes import c_uint, c_void_p

import torch

# Load the CUDA runtime library
try:
    cuda_runtime = ctypes.CDLL("libcudart.so")  # Linux
except OSError:
    try:
        # Windows: the DLL name is versioned; adjust to your CUDA version
        cuda_runtime = ctypes.CDLL("cudart64_12.dll")
    except OSError:
        cuda_runtime = ctypes.CDLL("libcudart.dylib")  # macOS


# Define CUDA constants
cudaStreamNonBlocking = 0x01
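# (for contrast, cudaStreamDefault, the blocking variant, is 0x00)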


def is_stream_non_blocking(pytorch_stream: torch.cuda.Stream) -> bool:
    """Check if a PyTorch CUDA stream is non-blocking."""
    try:
        # Get the raw CUDA stream handle
        stream_ptr = pytorch_stream.cuda_stream

        # Prepare arguments for cudaStreamGetFlags
        flags = c_uint()

        # Call cudaStreamGetFlags
        result = cuda_runtime.cudaStreamGetFlags(
            c_void_p(stream_ptr), ctypes.byref(flags)
        )

        if result != 0:  # cudaSuccess = 0
            print(f"Error getting stream flags: {result}")
            return False

        return bool(flags.value & cudaStreamNonBlocking)

    except Exception as e:
        print(f"Error checking stream flags: {e}")
        return False


def h2d_func():
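    # Continuously issue H2D copies on a fresh pool stream, which
    # is_stream_non_blocking() below reports as non-blocking.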
    stream_h2d = torch.cuda.Stream()
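    # Pinned (page-locked) host memory lets the H2D copy run asynchronously.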
    t = torch.ones([10000], device=torch.device("cpu"), pin_memory=True)
    with torch.cuda.stream(stream_h2d):
        print(f"stream_h2d is non-blocking: {is_stream_non_blocking(stream_h2d)}")
        for _ in range(100000000):  # effectively infinite; keep the copy engine busy
            cuda_t = t.to("cuda", non_blocking=True)


def main() -> None:
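    # Run the H2D loop in a background thread while this thread issues
    # D2H copies from the default (blocking) stream.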
    h2d_thread = threading.Thread(target=h2d_func, args=())
    h2d_thread.start()
    list_t = torch.empty([1])
    cuda_t = torch.arange(1000000, device=torch.device("cuda"))
    stream_main = torch.cuda.Stream()
    print(
        f"Current stream is non-blocking: {is_stream_non_blocking(torch.cuda.current_stream())}"
    )
    # with torch.cuda.stream(stream_main):  # uncomment (and indent the loop) to move these ops off the default stream
    for _ in range(10000000):
        for _ in range(100):
            t = cuda_t.repeat_interleave(10)
            t = cuda_t.repeat_interleave(10)
        list_t = t.to("cpu", non_blocking=True)

    h2d_thread.join()
    print("h2d thread is done")


if __name__ == "__main__":
    main()

Result:

stream_h2d is non-blocking: True
Current stream is non-blocking: False

From the GPU trace, we can confirm the blocking: the copies issued from the two streams serialize instead of overlapping.

[trace screenshot]
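For reference, this is roughly how such a trace can be captured (a sketch; it assumes a shortened variant of main() with small loop bounds, since the loops above run effectively forever):

from torch.profiler import profile, ProfilerActivity

# Record CPU + CUDA activity and export a Chrome/Perfetto-compatible trace.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    main()  # hypothetical shortened repro with small loop bounds
prof.export_chrome_trace("trace.json")  # open in chrome://tracing or Perfetto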

If we instead run all copy operations on non-default CUDA streams, we observe that they do run in parallel:

    with torch.cuda.stream(stream_main):
        for _ in range(10000000):
            for _ in range(100):
                t = cuda_t.repeat_interleave(10)
                t = cuda_t.repeat_interleave(10)
            list_t = t.to("cpu", non_blocking=True)

My question is: what is causing the blocking here?

Trace after using a separate CUDA stream:

[trace screenshot]