CUBLAS_STATUS_NOT_INITIALIZED when running CUDAGraph capture

Hi,

I am attempting to use the CUDAGraph API in libtorch. However, I run into the following error:

terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

Sample code:

#include <torch/torch.h>
#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAStream.h>


torch::Tensor linear_sigmoid(const torch::Tensor &A, const torch::Tensor &x, const torch::Tensor &b)
{
    return torch::sigmoid(A.matmul(x) + b);
}


int main()
{
    torch::StreamGuard stream_guard{at::cuda::getStreamFromPool(true)};
    auto x = torch::randn({5, 5, 1}).cuda();
    auto A = torch::randn({5, 2, 5}).cuda();
    auto b = torch::randn({5, 2, 1}).cuda();

    at::cuda::CUDAGraph graph{};
    graph.capture_begin();
    auto output = linear_sigmoid(A, x, b);
    graph.capture_end();
    std::cout << output << "\n";

    x.copy_(torch::randn({5, 5, 1}).cuda());
    A.copy_(torch::randn({5, 2, 5}).cuda());
    b.copy_(torch::randn({5, 2, 1}).cuda());

    graph.replay();
    std::cout << output << "\n";
}

I don’t see warmup iterations being done here, which are required before capture (e.g., see Accelerating PyTorch with CUDA Graphs | PyTorch). The first matmul lazily creates the cuBLAS handle, and that initialization cannot happen while the stream is being captured, which is why the capture fails with CUBLAS_STATUS_NOT_INITIALIZED.
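
Something along these lines should work (a minimal sketch reusing linear_sigmoid and the tensors from your snippet; the explicit stream synchronize before capture_begin() is just how I would separate the warmup work from the capture and may not be strictly necessary):

    auto stream = at::cuda::getStreamFromPool();
    torch::StreamGuard stream_guard{stream};
    auto x = torch::randn({5, 5, 1}).cuda();
    auto A = torch::randn({5, 2, 5}).cuda();
    auto b = torch::randn({5, 2, 1}).cuda();

    // Warmup run before capture: the first matmul lazily creates the
    // cuBLAS handle, which cannot be done while the stream is being captured.
    auto output = linear_sigmoid(A, x, b);
    stream.synchronize();

    at::cuda::CUDAGraph graph{};
    graph.capture_begin();
    output = linear_sigmoid(A, x, b);  // recorded into the graph, not executed yet
    graph.capture_end();

    graph.replay();  // now output holds real results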

Amazing, thank you for the quick reply.

Updated code snippet for anybody coming here in the future:

#include <torch/torch.h>
#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAStream.h>


torch::Tensor linear_sigmoid(const torch::Tensor &A, const torch::Tensor &x, const torch::Tensor &b)
{
    return torch::sigmoid(A.matmul(x) + b);
}


int main()
{
    // Graph capture must happen on a non-default (side) stream
    torch::StreamGuard stream_guard{at::cuda::getStreamFromPool()};
    auto x = torch::randn({5, 5, 1}).cuda();
    auto A = torch::randn({5, 2, 5}).cuda();
    auto b = torch::randn({5, 2, 1}).cuda();
    // Warmup iteration: runs eagerly and initializes cuBLAS before capture begins
    auto output = linear_sigmoid(A, x, b);

    at::cuda::CUDAGraph graph{};
    graph.capture_begin();
    output = linear_sigmoid(A, x, b);  // recorded into the graph, not executed during capture
    graph.capture_end();
    std::cout << output.mean() << "\n";  // output only holds real values after the first replay()

    for (int i = 0; i < 10; i++) {
        // Refill the captured input tensors in place; replay() then recomputes output
        x.copy_(torch::randn({5, 5, 1}).cuda());
        A.copy_(torch::randn({5, 2, 5}).cuda());
        b.copy_(torch::randn({5, 2, 1}).cuda());
        graph.replay();
        std::cout << output.mean() << "\n";
    }
}