torch.cuda.graph with embedding triggering "operation not permitted when stream is capturing"

I’m trying to run torch.nn.functional.embedding(static_input, weight) inside a CUDA graph, but I’m getting an "operation not permitted when stream is capturing" error. Based on a bisect, something like the code below (reduced from a more complex repro) worked prior to the following change: [allocator] Generalize recording to a pool by zdevito · Pull Request #96542 · pytorch/pytorch · GitHub, but gives me this error thereafter.

This is my simplified repro, which takes the example from the 2.1.0 RC docs (it seems to be the same example as in the 2.0.1 docs) and just adds an F.embedding call:

import torch
import torch.cuda
from torch import nn

g = torch.cuda.CUDAGraph()

# Placeholder input used for capture
static_input = torch.empty((5,), device="cuda", dtype=torch.int32)
weight = torch.empty((6,10), device="cuda")

# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
        torch.nn.functional.embedding(input=static_input, weight=weight)

torch.cuda.current_stream().wait_stream(s)

# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
with torch.cuda.graph(g):
    static_output = static_input * 2
    torch.nn.functional.embedding(input=static_input, weight=weight)

The error is "RuntimeError: CUDA error: operation not permitted when stream is capturing". Are there any docs/examples of how I should use an embedding inside a CUDA graph, or how I should think about this? I read through the commit that broke things and it seems to make sense, but I’m not sure how it would break an F.embedding call that previously worked.

It seems you are initializing the input and the weight matrix with torch.empty, which uses uninitialized memory and can easily make F.embedding fail with indexing errors. Does your code even work without CUDA Graphs?

Ah thanks, good spot on the uninitialized indices; I missed that while trying to stay as true to form as the original demo code. I think it’s crashing before it gets to the indexing, though.

But alas, running the same code with torch.empty replaced by torch.zeros (or torch.ones) in both places fails similarly: "CUDA error: operation not permitted when stream is capturing" on the call to cudaMalloc.

More detail with compute-sanitizer. Running the following:

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
import torch.cuda
from torch import nn

g = torch.cuda.CUDAGraph()

# Placeholder input used for capture
static_input = torch.zeros((5,), device="cuda", dtype=torch.int32)
weight = torch.zeros((6,10), device="cuda")

# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
        torch.nn.functional.embedding(input=static_input, weight=weight)

torch.cuda.current_stream().wait_stream(s)

# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
with torch.cuda.graph(g):
    static_output = static_input * 2
    torch.nn.functional.embedding(input=static_input, weight=weight)

gives this output for the main stack trace. It looks like wrapper_CUDA_mul_Tensor needs allocate_or_resize_outputs, which ends up calling cudaMalloc in a problematic way:

========= Program hit cudaErrorStreamCaptureUnsupported (error 900) due to "operation not permitted when stream is capturing" on CUDA API call to cudaMalloc.
=========     Saved host backtrace up to driver entry point at error
=========     Host Frame: [0x44a116]
=========                in /usr/local/cuda/compat/libcuda.so.1
=========     Host Frame:cudaMalloc [0x5167e]
=========                in /usr/local/cuda-12.1/lib64/libcudart.so.12
=========     Host Frame:c10::cuda::CUDACachingAllocator::Native::NativeCachingAllocator::allocate(unsigned long) const [0x30caf]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libc10_cuda.so
=========     Host Frame:c10::StorageImpl::StorageImpl(c10::StorageImpl::use_byte_size_t, c10::SymInt const&, c10::Allocator*, bool) [0x1133bfa]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::TensorBase at::detail::_empty_generic<long>(c10::ArrayRef<long>, c10::Allocator*, c10::DispatchKeySet, c10::ScalarType, c10::optional<c10::MemoryFormat>) [0x1136b72]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::detail::empty_generic(c10::ArrayRef<long>, c10::Allocator*, c10::DispatchKeySet, c10::ScalarType, c10::optional<c10::MemoryFormat>) [0x11303d8]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::detail::empty_cuda(c10::ArrayRef<long>, c10::ScalarType, c10::optional<c10::Device>, c10::optional<c10::MemoryFormat>) [0xd8f871]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::detail::empty_cuda(c10::ArrayRef<long>, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, c10::optional<c10::MemoryFormat>) [0xd8fa08]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::detail::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&) [0xd8fb76]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::(anonymous namespace)::create_out(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::TensorOptions const&) [0x2deb349]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::(anonymous namespace)::structured_mul_out_functional::set_output_raw_strided(long, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::TensorOptions, c10::ArrayRef<at::Dimname>) [0x2f3dc93]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::TensorIteratorBase::allocate_or_resize_outputs() [0x11dde84]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::TensorIteratorBase::build(at::TensorIteratorConfig&) [0x11e242b]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::TensorIteratorBase::build_borrowing_binary_op(at::TensorBase const&, at::TensorBase const&, at::TensorBase const&) [0x11e3d84]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::(anonymous namespace)::wrapper_CUDA_mul_Tensor(at::Tensor const&, at::Tensor const&) [0x2e910e9]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&), &at::(anonymous namespace)::wrapper_CUDA_mul_Tensor>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&> >, at::Tensor (at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x2e91268]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
=========     Host Frame:at::_ops::mul_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x1cd96cf]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::VariableType::(anonymous namespace)::mul_Tensor(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x3aefddc]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::mul_Tensor>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&) [0x3af049b]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:at::_ops::mul_Tensor::call(at::Tensor const&, at::Tensor const&) [0x1d31f8f]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
=========     Host Frame:torch::autograd::THPVariable_mul(_object*, _object*, _object*) [0x4a2cb9]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_python.so
=========     Host Frame:_object* torch::autograd::TypeError_to_NotImplemented_<&torch::autograd::THPVariable_mul>(_object*, _object*, _object*) [0x4a2d6b]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/python3.11/site-packages/torch/lib/libtorch_python.so
=========     Host Frame:Objects/descrobject.c:366:method_vectorcall_VARARGS_KEYWORDS [0x16d6e2]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Objects/typeobject.c:1693:vectorcall_maybe.constprop.0 [0x1d252a]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Objects/typeobject.c:7421:slot_nb_multiply [0x1d2bd1]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Objects/abstract.c:1111:PyNumber_Multiply [0x143f73]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Python/ceval.c:5555:_PyEval_EvalFrameDefault [0x1048d8]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Python/ceval.c:1154:PyEval_EvalCode [0x25ad9f]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Python/pythonrun.c:1733:run_mod [0x2a621d]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Python/pythonrun.c:262:PyRun_InteractiveOneObjectEx [0x2a64aa]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Python/pythonrun.c:138:_PyRun_InteractiveLoopObject [0x2a7726]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Python/pythonrun.c:73:_PyRun_AnyFileObject [0x2a809e]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Python/pythonrun.c:106:PyRun_AnyFileExFlags [0x2a812e]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Modules/main.c:680:Py_RunMain [0x2c85a1]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0
=========     Host Frame:Modules/main.c:735:Py_BytesMain [0x2c883e]
=========                in /user_dir/.pyenv/versions/3.11.5/lib/libpython3.11.so.1.0

The code works for me using valid inputs:

import torch
import torch.cuda
from torch import nn

g = torch.cuda.CUDAGraph()

# Placeholder input used for capture
static_input = torch.zeros((5,), device="cuda", dtype=torch.int32)
weight = torch.randn((6,10), device="cuda")

# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
        out = torch.nn.functional.embedding(input=static_input, weight=weight)

torch.cuda.current_stream().wait_stream(s)

# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
with torch.cuda.graph(g):
    static_output = static_input * 2
    out = torch.nn.functional.embedding(input=static_input, weight=weight)

g.replay()
print(out)
# tensor([[ 0.2389,  0.2930, -0.7131,  0.0843,  0.8221,  0.4242,  0.3843,  1.5787,
#           2.1157,  1.8970],
#         [ 0.2389,  0.2930, -0.7131,  0.0843,  0.8221,  0.4242,  0.3843,  1.5787,
#           2.1157,  1.8970],
#         [ 0.2389,  0.2930, -0.7131,  0.0843,  0.8221,  0.4242,  0.3843,  1.5787,
#           2.1157,  1.8970],
#         [ 0.2389,  0.2930, -0.7131,  0.0843,  0.8221,  0.4242,  0.3843,  1.5787,
#           2.1157,  1.8970],
#         [ 0.2389,  0.2930, -0.7131,  0.0843,  0.8221,  0.4242,  0.3843,  1.5787,
#           2.1157,  1.8970]], device='cuda:0')
print(weight)
# tensor([[ 0.2389,  0.2930, -0.7131,  0.0843,  0.8221,  0.4242,  0.3843,  1.5787,
#           2.1157,  1.8970],
#         [ 1.9481, -1.2948, -0.1860, -0.1144, -0.2101,  1.0376, -0.5041,  0.8710,
#          -2.0183, -1.0133],
#         [-0.4531, -0.7180,  0.6278,  1.3979, -1.2674,  0.3704,  1.1185, -0.2874,
#          -0.1480,  0.3355],
#         [-0.5120,  1.0218, -0.2350,  1.0649,  0.0083,  0.0147, -0.7148, -0.2410,
#          -0.3195,  0.4420],
#         [ 0.0666, -0.8115, -0.9631,  0.1666,  1.1810,  0.5338, -0.7471,  2.3101,
#          -0.7206,  1.3824],
#         [ 0.2561, -1.2169,  0.0993, -0.2104,  2.0112,  0.4424,  2.2243,  0.0525,
#           0.6025, -0.9965]], device='cuda:0')

Thank you for bearing with me. I learned that my small repro was not actually replicating the problem I thought it was. Two things:

  • I was running cuda-memcheck to figure out where things went sideways, and I had PYTORCH_NO_CUDA_MEMORY_CACHING=1 set. That variable disables the caching allocator, so allocations inside the graph become raw cudaMalloc calls, which are not permitted during capture. Should not use that while capturing (see the sketch after this list)!
  • I was running the code in a python shell, which actually made things worse: the REPL echoes each returned tensor, and that GPU->CPU copy is also illegal during graph capture.
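For anyone who hits the same wall, here is a minimal sketch of both pitfalls, distilled from the repro above (warmup kept short for brevity); the commented-out lines are the ones that trigger "operation not permitted when stream is capturing":

import os
# Pitfall 1: this disables the caching allocator, so every tensor allocation
# during capture becomes a raw cudaMalloc, which is illegal while capturing.
# os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"

import torch

g = torch.cuda.CUDAGraph()
static_input = torch.zeros((5,), device="cuda")

# Warmup on a side stream so capture doesn't see first-time allocations
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(g):
    static_output = static_input * 2
    # Pitfall 2: anything that copies to the host during capture fails,
    # including the REPL echoing a tensor or an explicit print:
    # print(static_output)

g.replay()
print(static_output)  # fine here: capture has ended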

If I run it as a python script (not a REPL), without PYTORCH_NO_CUDA_MEMORY_CACHING=1, then yes, the listed code works just fine.

The real bug I was trying to repro seems to come from running two streams within a single graph, which PyTorch made a bit trickier as of that commit, and it is entirely unrelated to the F.embedding red herring I found myself chasing after misusing the debugging tools.
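For reference, the multi-stream-capture pattern I was after looks roughly like this (a sketch, not my actual code): during capture, a side stream has to be forked from and joined back to the capture stream via the same wait_stream dance used in warmup.

import torch

g = torch.cuda.CUDAGraph()
static_input = torch.zeros((5,), device="cuda")
side = torch.cuda.Stream()  # create streams outside of capture

# Warmup: run the same ops the graph will capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        a = static_input * 2
        b = static_input * 3
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(g):
    # Fork: the side stream must first wait on the capture stream...
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        b = static_input * 3
    a = static_input * 2
    # ...and join: the capture stream must wait on the side stream
    # before capture ends, so the side-stream work is part of the graph.
    torch.cuda.current_stream().wait_stream(side)

g.replay()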


Great debugging and thanks for sharing the update!