Libtorch, CUDAGraph, and indexing

Hi,

I am running into issues when capturing a CUDA graph in libtorch, specifically around indexing. The example below fails with an error along the lines of "operation is not permitted when capturing"; it is the indexing inside the capture that fails.

#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAGraph.h>
#include <torch/torch.h>
#include <iostream>


int main()
{
    torch::StreamGuard stream_guard{at::cuda::getStreamFromPool()};
    at::cuda::CUDAGraph graph{};
    auto options = torch::TensorOptions{}.device(torch::kCUDA);

    // Warm-up run before capturing.
    auto x = torch::randn({25, 5}, options);
    c10::List<c10::optional<at::Tensor>> index_list{x < 0};
    x = x.index_put(index_list, - x.index(index_list));
    auto y = x.square();

    x.copy_(torch::randn({25, 5}, options));
    graph.capture_begin();
    // Same indexing again, now inside the capture; this is what fails.
    index_list = c10::List<c10::optional<at::Tensor>>{x < 0};
    x = x.index_put(index_list, -x.index(index_list));
    y = x.square();
    graph.capture_end();
    std::cout << y;
}

Similarly, this does not work in Python. The snippet further below fails with one of two seemingly random errors:

Traceback (most recent call last):
  File "/home/jakkes/projects/rl-cpp/test.py", line 16, in <module>
    x[x < 0] = -x[x < 0]
RuntimeError: numel: integer multiplication overflow

or

Traceback (most recent call last):
  File "/home/jakkes/projects/rl-cpp/test.py", line 16, in <module>
    x[x < 0] = -x[x < 0]
RuntimeError: status != cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":1220, please report a bug to PyTorch.

The snippet producing these errors:
import torch


if __name__ == "__main__":
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        x = torch.randn(25, 5, device="cuda")
        x[x < 0] = -x[x < 0]
        y = x.square()
    torch.cuda.current_stream().wait_stream(s)
    
    x.copy_(torch.randn(25, 5, device="cuda"))
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # This boolean-mask indexing is what fails during capture.
        x[x < 0] = -x[x < 0]
        y = x.square()

    for _ in range(3):
        x.copy_(torch.randn(25, 5, device="cuda"))
        g.replay()
        print(y)

In this example, I can use torch::where to work around the indexing issue. However, in my actual use case I need the accumulate feature, i.e.

x = x.index_put_({some_indices}, some_values, /* accumulate */ true);

and since x is a rather large matrix, I would rather not build a big loop of torch::where calls.
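
For reference, here is roughly what the torch::where workaround looks like for the toy example, as a quick Python sketch (the output of where has a fixed shape, so the data-dependent indexing disappears):

import torch

x = torch.randn(25, 5, device="cuda")
# Equivalent to x[x < 0] = -x[x < 0], but with a fixed output shape,
# so nothing has to synchronize with the host.
x = torch.where(x < 0, -x, x)
y = x.square()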

Any advice is very much welcomed, thank you.

Indexing with a BoolTensor is disallowed during capture because this operation synchronizes with the host: the result tensor has a data-dependent shape (i.e. its shape depends on the number of True entries).
You would either need to move this operation out of the capture or try to replace it with an indexing operation using indices.
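
If the integer indices can be computed (or at least sized) outside the capture, index_put_ with accumulate=True should also be capturable, since the index tensors then have a fixed shape and nothing depends on the data at capture time. A rough sketch, where rows, cols and values are made-up placeholders for whatever the real use case provides:

import torch

x = torch.zeros(25, 5, device="cuda")
# Hypothetical static index and value tensors; in a real use case they would be
# filled with meaningful data and updated in place between replays.
rows = torch.randint(0, 25, (100,), device="cuda")
cols = torch.randint(0, 5, (100,), device="cuda")
values = torch.randn(100, device="cuda")

# Warm-up on a side stream, as in the snippets above.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    x.index_put_((rows, cols), values, accumulate=True)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Integer-index put with accumulate: all shapes are fixed, so this does not
    # synchronize with the host and can be captured.
    x.index_put_((rows, cols), values, accumulate=True)
    y = x.square()

for _ in range(3):
    # Refill the static inputs in place, then replay the captured graph.
    values.copy_(torch.randn(100, device="cuda"))
    g.replay()
    print(y)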

Okay, I see. Thank you.