Hi,
I am running into issues when capturing a CUDA graph in libtorch, specifically around indexing. The example below fails with an error along the lines of "operation is not permitted when capturing"; it is the indexing inside the capture region that fails.
#include <iostream>

#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAGraph.h>
#include <torch/torch.h>

int main()
{
    // Graph capture must run on a non-default stream.
    torch::StreamGuard stream_guard{at::cuda::getStreamFromPool()};
    at::cuda::CUDAGraph graph{};

    auto options = torch::TensorOptions{}.device(torch::kCUDA);
    auto x = torch::randn({25, 5}, options);

    // Warm-up: run the ops once before capturing.
    c10::List<c10::optional<at::Tensor>> index_list{x < 0};
    x = x.index_put(index_list, -x.index(index_list));
    auto y = x.square();

    x.copy_(torch::randn({25, 5}, options));

    graph.capture_begin();
    c10::List<c10::optional<at::Tensor>> capture_index_list{x < 0};
    x = x.index_put(capture_index_list, -x.index(capture_index_list));
    y = x.square();
    graph.capture_end();

    std::cout << y;
}
Similarly, this does not work in Python. The snippet below fails with one of two seemingly random errors:
Traceback (most recent call last):
  File "/home/jakkes/projects/rl-cpp/test.py", line 16, in <module>
    x[x < 0] = -x[x < 0]
RuntimeError: numel: integer multiplication overflow
or
Traceback (most recent call last):
  File "/home/jakkes/projects/rl-cpp/test.py", line 16, in <module>
    x[x < 0] = -x[x < 0]
RuntimeError: status != cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":1220, please report a bug to PyTorch.
import torch

if __name__ == "__main__":
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        x = torch.randn(25, 5, device="cuda")
        x[x < 0] = -x[x < 0]
        y = x.square()
    torch.cuda.current_stream().wait_stream(s)

    x.copy_(torch.randn(25, 5, device="cuda"))

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        x[x < 0] = -x[x < 0]
        y = x.square()

    for _ in range(3):
        x.copy_(torch.randn(25, 5, device="cuda"))
        g.replay()
        print(y)
In this example, I can use torch::where to work around the indexing issue. However, in my actual use case I need the accumulate feature, i.e.

x = x.index_put_({some_indices}, some_values, /* accumulate */ true)

and since x is a rather large matrix, I would rather not have to build a big loop of torch::where calls.
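For reference, this is the kind of torch::where rewrite I mean for the toy example above, as a minimal sketch (the helper name is just for illustration). It is capture-safe because the output shape is static, but it offers no accumulate option:

#include <torch/torch.h>

// Same effect as x = x.index_put(index_list, -x.index(index_list)):
// the output shape is fixed, so nothing forces a host synchronization
// during capture.
torch::Tensor negate_negatives(const torch::Tensor& x)
{
    return torch::where(x < 0, -x, x);
}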
Any advice is very much welcome. Thank you.