Torch graph merging with CUDA graph does not work

I am trying to merge a large CUDA graph with an inference (and eventually a training) CUDA graph.

To do so, I am using a local version of libtorch that is slightly modified from v2.0.1.

I delete the following:

  // check if debug path is set
  if (!_cuda_graphs_debug) {
    // Now that we've instantiated graph_ into graph_exec_,
    // we don't need graph_ anymore.
    AT_CUDA_CHECK(cudaGraphDestroy(graph_));
    has_graph_ = false;
  } else {
    TORCH_WARN("DEBUG: TORCH_CUDAGRAPHS_DEBUG_PATH detected. graph_ will not be freed until debug_dump is called.");
  }

in aten/src/ATen/cuda/CUDAGraph.cpp,
and make the following members public in CUDAGraph.h:

  cudaGraph_t graph_ = NULL;
  cudaGraphExec_t graph_exec_ = NULL;
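
For context on why this patch works at all: cudaGraphAddChildGraphNode clones the child graph at insertion time, so the torch-owned graph_ only has to stay alive until the merge happens. A minimal plain-CUDA sketch of that mechanism (someKernel is just a placeholder, error checks omitted):

  cudaGraph_t parent, child;
  cudaStream_t s;
  cudaGraphCreate(&parent, 0);
  cudaStreamCreate(&s);

  // capture an arbitrary kernel into its own graph
  cudaStreamBeginCapture(s, cudaStreamCaptureModeGlobal);
  someKernel<<<1, 32, 0, s>>>();
  cudaStreamEndCapture(s, &child);

  // the node owns a clone, so the source graph can be destroyed afterwards
  cudaGraphNode_t childNode;
  cudaGraphAddChildGraphNode(&childNode, parent, nullptr, 0, child);
  cudaGraphDestroy(child);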

To test the graph merge and execution outside of the traditional structure I have a relatively short script.

#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

#include <iomanip>
#include <iostream>

#include "helper_cuda.h"

void printTensorSize(torch::Tensor& tensor) {
    std::cout << "Size: [";
    for (size_t i = 0; i < tensor.sizes().size(); ++i) {
        std::cout << tensor.sizes()[i];
        if (i < tensor.sizes().size() - 1) {
            std::cout << ", ";
        }
    }
    std::cout << "]" << std::endl;
}

void printTensor(torch::Tensor& tensor) {
    if (tensor.dim() != 2) {
        std::cout << "Error: Tensor is not 2-dimensional." << std::endl;
        return;
    }

    for (int64_t i = 0; i < tensor.size(0); ++i) {
        for (int64_t j = 0; j < tensor.size(1); ++j) {
            std::cout << std::setw(10) << std::setprecision(4)
                      << tensor[i][j].item<float>() << " ";
        }
        std::cout << std::endl;
    }
    std::cout << "Address of the data: " << tensor.data_ptr() << std::endl;
}

struct Net : torch::nn::Module {
    torch::nn::Linear linear1, linear2, linear3;
    torch::Tensor preallocated_output_tensor;
    Net(int64_t input, int64_t hidden1, int64_t hidden2, int64_t output, int num_av, float* output_arr, torch::TensorOptions& options)
        : linear1(
              register_module("linear1", torch::nn::Linear(input, hidden1))),
          linear2(
              register_module("linear2", torch::nn::Linear(hidden1, hidden2))),
          linear3(
              register_module("linear3", torch::nn::Linear(hidden2, output))) {
                preallocated_output_tensor = torch::from_blob(output_arr, {num_av, output}, options);
              }

    // TODO: make in place version
    torch::Tensor forward(torch::Tensor x) {
        x = linear1->forward(x);
        torch::relu_(x);

        x = linear2->forward(x);
        torch::relu_(x);

        x = linear3->forward(x);

        auto batch_size = x.size(0);
        x = torch::arange(0, 4, torch::dtype(torch::kFloat32)).repeat({batch_size, 1});
        preallocated_output_tensor = x;
        return x;
    }
};

__global__ void printOutput(int num_avs,float4* output){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    if(tid == 0){
        printf("cudaptr: 0x%p\n", output);
    }
    if(tid < num_avs){
        printf("tid: %d, Agent Output: %f %f %f %f\n ", tid, output[tid].x, output[tid].y, output[tid].z, output[tid].w);
    }
}

__global__ void generateData(int num_avs, int set,float4* input){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    if(tid < num_avs){
        input[tid].x = set;
        input[tid].y = set;
        input[tid].z = set;
        input[tid].w = set;
    }
}

int main() {
    torch::Device device(torch::kCUDA);
    int num_av = 10;
    int input_size = 40; // must be divisible by 4
    int output_size = 4; // must be divisible by 4
    float* input_arr, *output_arr;

    // allocate input and output arrays for underlying data in Tensors
    cudaError_t cudaStatus = cudaMallocManaged(&input_arr, input_size*num_av*sizeof(float));
    cudaMemset(input_arr, 0, input_size*num_av*sizeof(float));
    cudaStatus = cudaMallocManaged(&output_arr, output_size*num_av*sizeof(float));
    cudaMemset(output_arr, 0, output_size*num_av*sizeof(float));

    // allocate input and output tensors based on constant address internal arrays
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(device);
    torch::Tensor input_tensor = torch::from_blob(input_arr, {num_av, input_size}, options);
    torch::Tensor output_tensor = torch::from_blob(output_arr, {num_av, output_size}, options);

    // create network
    Net model(input_size, 120, 20, output_size, num_av, output_arr, options);
    model.to(device);

    // warm start
    torch::Tensor _;
    for (int i = 0; i < 6; i++) {
        _ = model.forward(input_tensor);
        printTensor(output_tensor);
    }

    // capture forward
    at::cuda::CUDAStream torchStream = at::cuda::getStreamFromPool(true);
    at::cuda::setCurrentCUDAStream(torchStream);
    at::cuda::CUDAGraph forward_graph;
    forward_graph.capture_begin();
    _ = model.forward(input_tensor);
    forward_graph.capture_end();

    void* kernelArgs2[2];
    void* kernelArgs3[3];
    cudaGraph_t graph;
    cudaGraphCreate(&graph, 0);
    cudaGraphNode_t model_forward_node, set_input_node, print_output_node;
    cudaKernelNodeParams kernelNodeParams;
    int set_num = 1;

    kernelNodeParams.func = (void*) generateData;
    kernelNodeParams.gridDim = dim3(num_av, 1, 1);
    kernelNodeParams.blockDim = dim3(1, 1, 1);
    kernelNodeParams.sharedMemBytes = 0;
    kernelArgs3[0] = (void*)&num_av;
    kernelArgs3[1] = (void*)&set_num;
    kernelArgs3[2] = (void*)&input_arr;
    kernelNodeParams.kernelParams = kernelArgs3;
    kernelNodeParams.extra = NULL;
    cudaGraphAddKernelNode(&set_input_node, graph, nullptr, 0, &kernelNodeParams);

    checkCudaErrors(cudaGraphAddChildGraphNode(&model_forward_node, graph, &set_input_node, 1, forward_graph.graph_));

    kernelNodeParams.func = (void*) printOutput;
    kernelNodeParams.gridDim = dim3(num_av, 1, 1);
    kernelNodeParams.blockDim = dim3(1, 1, 1);
    kernelNodeParams.sharedMemBytes = 0;
    kernelArgs2[0] = (void*)&num_av;
    kernelArgs2[1] = (void*)&output_arr;
    kernelNodeParams.kernelParams = kernelArgs2;
    kernelNodeParams.extra = NULL;
    cudaGraphAddKernelNode(&print_output_node, graph, &model_forward_node, 1, &kernelNodeParams);

    // Launch the graph
    cudaGraphExec_t graphExec;
    cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0);
    cudaGraphLaunch(graphExec, 0);
    cudaDeviceSynchronize();
    printTensor(output_tensor);

    // Launch the graph a second time
    cudaGraphLaunch(graphExec, 0);
    cudaDeviceSynchronize();
    printTensor(output_tensor);

    // Cleanup
    cudaGraphExecDestroy(graphExec);
    cudaGraphDestroy(graph);
    cudaFree(input_arr);
    cudaFree(output_arr);

    return 0;
}

I set a preallocated tensor to try to guarantee a constant address for the rest of the CUDA graph and for the underlying data in output_arr. However, the output is always zero or garbage data.
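
One libtorch semantic that may be relevant here (I note it as an assumption about where the zeros come from): assigning to a torch::Tensor variable rebinds the handle to the right-hand side's storage; it does not write into the existing storage, so output_arr never receives the result. Only copy_ writes through:

  preallocated_output_tensor = x;       // rebinds to x's storage; output_arr untouched
  preallocated_output_tensor.copy_(x);  // writes x's values into output_arr's storage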

How would I successfully merge CUDA graphs?
Is there a better way to do this without source changes?
How can I best help the community address this (more info, a PR, etc.)?

For more context: the graph recording seems to work if I don’t merge graphs.

This MNIST example works for me (with the torch source changes); it is more involved.

#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

#include "helper_cuda.h"

/**
 * Simple example showing how to capture MLP in a cuda graph,
 * running it through training and evaluation on MNIST data
 */

struct Net : torch::nn::Module {
    torch::nn::Linear linear1, linear2, linear3;

    Net(int64_t input, int64_t hidden1, int64_t hidden2, int64_t output)
        : linear1(register_module("linear1", torch::nn::Linear(input, hidden1))),
          linear2(register_module("linear2", torch::nn::Linear(hidden1, hidden2))),
          linear3(register_module("linear3", torch::nn::Linear(hidden2, output))) {}

    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(linear1->forward(x));
        x = torch::relu(linear2->forward(x));
        return linear3->forward(x);
    }
};

int main() {
    // ============================ constants ============================ //
    int64_t kBatchSize = 32;
    int64_t kNumberOfEpochs = 10;
    const double learning_rate = 0.001;

    const int input_size = 28 * 28;
    const int hidden1_size = 50;
    const int hidden2_size = 50;
    const int output_size = 10;
    // ========================== end constants ========================== //

    // instantiate network and move to GPU
    torch::Device device(torch::kCUDA);
    Net net(input_size, hidden1_size, hidden2_size, output_size);
    net.to(device);

    // loss function and optimizer
    torch::nn::CrossEntropyLoss loss_fn;
    torch::optim::SGD optimizer(net.parameters(), learning_rate);

    at::cuda::CUDAStream torchStream = at::cuda::getStreamFromPool(true);
    at::cuda::setCurrentCUDAStream(torchStream);
    torchStream.synchronize();

    // =========== DO NOT REMOVE: Necessary to warm start torch! =========== //
    c10::optional<bool> yes = true;
    c10::TensorOptions type = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kInt64);
    c10::TensorOptions option = torch::TensorOptions().device(torch::kCUDA);
    
    // Input and output must exist outside the scope of capturing and running
    // the graph; a side stream must be created to capture the CUDA graph.
    torch::Tensor warmup = at::randn({kBatchSize, input_size});

    torch::Tensor input_shared = torch::empty_like(warmup, option);

    torch::Tensor target_shared = torch::ones({kBatchSize}, type);

    torch::Tensor output_shared = torch::empty({warmup.size(0), 10}, option);  // Assuming you have 10 output classes

    torch::Tensor output_before = torch::empty({warmup.size(0), 10}, option);

    torch::Tensor loss = torch::empty_like(warmup, option);

    for (int i = 0; i < 3; i++) {
        optimizer.zero_grad(true);
        torch::Tensor warmup_output = net.forward(input_shared);

        torch::Tensor warmup_loss = loss_fn(warmup_output, target_shared);
        warmup_loss.backward();
        optimizer.step();
    }

    torchStream.synchronize();
    // ============================ end warmup ============================ //

    // ======================= capture libtorch graph ======================= //
    at::cuda::CUDAGraph graph;
    optimizer.zero_grad(true);

    graph.capture_begin();

    output_shared = net.forward(input_shared);
    loss = loss_fn(output_shared, target_shared);
    loss.backward();
    optimizer.step();

    graph.capture_end();
    // ============================= end capture ============================= //

    // Load mnist datasets
    auto train_dataset = torch::data::datasets::MNIST(
            "../data", 
            torch::data::datasets::MNIST::Mode::kTrain
        )
        .map(torch::data::transforms::Normalize<>(0.5, 0.5))
        .map(torch::data::transforms::Stack<>());

    auto test_dataset = torch::data::datasets::MNIST(
            "../data", 
            torch::data::datasets::MNIST::Mode::kTest
        )
        .map(torch::data::transforms::Normalize<>(0.5, 0.5))
        .map(torch::data::transforms::Stack<>());
    
    // instantiate mnist data loaders
    auto train_loader = torch::data::make_data_loader(
        std::move(train_dataset),
        torch::data::DataLoaderOptions()
            .batch_size(kBatchSize)
            .drop_last(true) // skip the last partial batch and prevent size mismatch
            .workers(1)
        );

    auto test_loader = torch::data::make_data_loader(
        std::move(test_dataset),
        torch::data::DataLoaderOptions()
            .batch_size(kBatchSize)
            .drop_last(true) // skip the last partial batch and prevent size mismatch
            .workers(1)
        );

    // ============================= mnist training ============================= //
    for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
        int64_t batch_index = 0;

        for (torch::data::Example<>& batch : *train_loader) {
            output_before = output_shared.clone();
        
            input_shared.copy_(batch.data.view({-1, input_size}));
            target_shared.copy_(batch.target);

            graph.replay();

            printf("loss: %.4f\n", loss.item<float>());
            // std::cout << "Output after: \n" << output_shared.slice(0,0,5) << std::endl;
            if (torch::equal(output_before, output_shared)) {
                std::cout << "Warning: output_shared is not updated!" << std::endl;
            } else {
                std::cout << "output_shared is updated." << std::endl;
            }
        }
    }
    // ============================= end training ============================= //

    net.eval();

    // ============================= mnist test ============================= //
    for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
        int64_t batch_index = 0;
        
        for (torch::data::Example<>& batch : *test_loader) {
            torch::Tensor output = net.forward(batch.data.view({-1, input_size}).to(device));
            auto pred = output.argmax(1);
            
            int64_t total = batch.data.size(0);
            int64_t correct = pred.eq(batch.target.to(device)).sum().item<int64_t>();
            std::cout << "Total: " << total << ", Correct: " << correct << std::endl;
        }
    }
    // ============================= end test ============================= //

    return 0;
}

I haven’t looked into the node-manipulation part of the CUDA Graphs API, so I’m not sure if there are any additional requirements there. Were you able to get a single-node example (without merging) working, either with PyTorch or with a standalone CUDA kernel?

I don’t understand what you mean by node manipulation. Do you mean the graph API instead of stream capture?
With the graph API (and, I believe, with stream capture as well), the pointers have to be known at graph creation (or capture) time. So when torch returns an output tensor with a different internal address for Tensor::data_ptr(), I cannot work with it: I believe the CUDA graph keeps reading the address I gave it at build time, which will change on the next forward. Is there a way to do all operations in place? Or to use a preallocated tensor properly? I have tried many iterations of the code.
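
For illustration, this is roughly the kind of all-out-variant forward I have in mind (a sketch only, not verified against v2.0.1; h1, h2, and out would be member tensors allocated once in the constructor):

  // each layer writes bias + input.mm(weight.t()) into a preallocated buffer,
  // so no new allocations should happen during capture
  torch::Tensor forward(const torch::Tensor& x) {
      at::addmm_out(h1, linear1->bias, x, linear1->weight.t());
      torch::relu_(h1);
      at::addmm_out(h2, linear2->bias, h1, linear2->weight.t());
      torch::relu_(h2);
      at::addmm_out(out, linear3->bias, h2, linear3->weight.t());
      return out;
  }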

My updated code is as follows (for context). I am currently trying to copy from the inconsistent address to a consistent address within the stream capture, to make it compatible with printOutput’s need for a constant address. However, I believe that on the first graph invocation copyOutput fails, as it is still trying to copy from the old data_ptr. There are other oddities, e.g. the first copyOutput does not print.

#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

#include <iomanip>
#include <iostream>

#include "helper_cuda.h"

void printTensorSize(torch::Tensor& tensor) {
    std::cout << "Size: [";
    for (size_t i = 0; i < tensor.sizes().size(); ++i) {
        std::cout << tensor.sizes()[i];
        if (i < tensor.sizes().size() - 1) {
            std::cout << ", ";
        }
    }
    std::cout << "]" << std::endl;
}

void printTensor(torch::Tensor& tensor) {
    if (tensor.dim() != 2) {
        std::cout << "Error: Tensor is not 2-dimensional." << std::endl;
        return;
    }

    for (int64_t i = 0; i < tensor.size(0); ++i) {
        for (int64_t j = 0; j < tensor.size(1); ++j) {
            std::cout << std::setw(10) << std::setprecision(4)
                      << tensor[i][j].item<float>() << " ";
        }
        std::cout << std::endl;
    }
    std::cout << "Address of the data: " << tensor.data_ptr() << std::endl;
}

struct Net : torch::nn::Module {
    torch::nn::Linear linear1, linear2, linear3;
    torch::Tensor preallocated_output_tensor;
    Net(int64_t input, int64_t hidden1, int64_t hidden2, int64_t output, int num_av, float* output_arr, torch::TensorOptions& options)
        : linear1(
              register_module("linear1", torch::nn::Linear(input, hidden1))),
          linear2(
              register_module("linear2", torch::nn::Linear(hidden1, hidden2))),
          linear3(
              register_module("linear3", torch::nn::Linear(hidden2, output))) {
                preallocated_output_tensor = torch::from_blob(output_arr, {num_av, output}, options);
              }

    // TODO: make in place version
    torch::Tensor forward(torch::Tensor x) {
        x = linear1->forward(x);
        torch::relu_(x);

        x = linear2->forward(x);
        torch::relu_(x);

        x = linear3->forward(x);

        auto batch_size = x.size(0);
        x = torch::arange(0, 4, torch::dtype(torch::kFloat32)).repeat({batch_size, 1});
        preallocated_output_tensor = x;
        return x;
    }
};

__global__ void copyOutput(int num_avs,float4* input, float4* output){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    // auto f4_input = (float4*) input;
    printf("copying: 0x%p to 0x%p\n", input, output);
    if(tid < num_avs){
        output[tid].x=input[tid].x;
        output[tid].y=input[tid].y;
        output[tid].z=input[tid].z;
        output[tid].w=input[tid].w;
    }
}

__global__ void printOutput(int num_avs,float4 * output){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    if(tid == 0){
        printf("cudaptr: 0x%p\n", output);
    }
    if(tid < num_avs){
        printf("tid: %d, Agent Output: %f %f %f %f\n ", tid, 
            output[tid].x,
            output[tid].y,
            output[tid].z,
            output[tid].w);
    }
}

__global__ void generateData(int num_avs, int set,float4* input){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    printf("generating: 0x%p\n", input);
    if(tid < num_avs){
        input[tid].x = set;
        input[tid].y = set;
        input[tid].z = set;
        input[tid].w = set;
    }
}

int main() {
    torch::Device device(torch::kCUDA);
    int num_av = 10;
    int input_size = 40; // must be divisible by 4
    int output_size = 4; // must be divisible by 4
    float* input_arr, *output_arr;

    // allocate input and output arrays for underlying data in Tensors
    cudaError_t cudaStatus = cudaMallocManaged(&input_arr, input_size*num_av*sizeof(float));
    cudaMemset(input_arr, 0, input_size*num_av*sizeof(float));
    cudaStatus = cudaMallocManaged(&output_arr, output_size*num_av*sizeof(float));
    cudaMemset(output_arr, 0, output_size*num_av*sizeof(float));

    // allocate input and output tensors based on constant address internal arrays
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(device);
    torch::Tensor input_tensor = torch::from_blob(input_arr, {num_av, input_size}, options);
    torch::Tensor output_tensor = torch::from_blob(output_arr, {num_av, output_size}, options);
    printTensor(output_tensor);

    // create network
    Net model(input_size, 120, 20, output_size, num_av, output_arr, options);
    model.to(device);

    // warm start
    for (int i = 0; i < 1; i++) {
        output_tensor = model.forward(input_tensor);
    }

    // capture forward
    at::cuda::CUDAStream torchStream = at::cuda::getStreamFromPool(true);
    at::cuda::setCurrentCUDAStream(torchStream);
    at::cuda::CUDAGraph forward_graph;
    forward_graph.capture_begin();
    output_tensor = model.forward(input_tensor);
    copyOutput<<<dim3(1,1,1),dim3(num_av,1,1), 0, torchStream>>>(num_av, (float4*)output_tensor.data_ptr(), (float4*)output_arr);
    forward_graph.capture_end();
    cudaDeviceSynchronize();

    void* kernelArgs2[2];
    void* kernelArgs3[3];
    cudaGraph_t graph;
    cudaGraphCreate(&graph, 0);
    cudaGraphNode_t model_forward_node, set_input_node, print_output_node;
    cudaKernelNodeParams kernelNodeParams;
    int set_num = 1;

    kernelNodeParams.func = (void*) generateData;
    kernelNodeParams.gridDim = dim3(1, 1, 1);
    kernelNodeParams.blockDim = dim3(num_av, 1, 1);
    kernelNodeParams.sharedMemBytes = 0;
    kernelArgs3[0] = (void*)&num_av;
    kernelArgs3[1] = (void*)&set_num;
    kernelArgs3[2] = (void*)&input_arr;
    kernelNodeParams.kernelParams = kernelArgs3;
    kernelNodeParams.extra = NULL;
    cudaGraphAddKernelNode(&set_input_node, graph, nullptr, 0, &kernelNodeParams);

    checkCudaErrors(cudaGraphAddChildGraphNode(&model_forward_node, graph, &set_input_node, 1, forward_graph.graph_));

    kernelNodeParams.func = (void*) printOutput;
    kernelNodeParams.gridDim = dim3(1, 1, 1);
    kernelNodeParams.blockDim = dim3(num_av, 1, 1);
    kernelNodeParams.sharedMemBytes = 0;
    kernelArgs2[0] = (void*)&num_av;
    kernelArgs2[1] = (void*)&output_arr;
    kernelNodeParams.kernelParams = kernelArgs2;
    kernelNodeParams.extra = NULL;
    cudaGraphAddKernelNode(&print_output_node, graph, &model_forward_node, 1, &kernelNodeParams);

    // Launch the graph
    cudaGraphExec_t graphExec;
    cudaGraphInstantiate(&graphExec, graph, NULL, NULL, 0);
    cudaGraphLaunch(graphExec, 0);
    cudaDeviceSynchronize();
    // printTensor(output_tensor);

    // Launch the graph a second time
    cudaGraphLaunch(graphExec, 0);
    cudaDeviceSynchronize();
    // printTensor(output_tensor);

    // Cleanup
    cudaGraphExecDestroy(graphExec);
    cudaGraphDestroy(graph);
    cudaFree(input_arr);
    cudaFree(output_arr);

    return 0;
}

Yes, I mean explicit adding of nodes vs. graph capture. In the graph-capture scenario, allocations done by PyTorch happen in a private memory pool to ensure that they are not recycled by the caching allocator after the graph capture. Without the private memory pool, any new allocations (e.g., for new inputs) could leave the captured memory addresses stale. You might get lucky with more warmup iterations? A litmus test for this: check whether any of your visible tensors have changing data_ptr()s; if they are changing, the addresses would likely go stale without a private memory pool.
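
Something like this (just a sketch, reusing model and input_tensor from your code) would show whether the addresses drift during warmup:

  for (int i = 0; i < 6; i++) {
      torch::Tensor y = model.forward(input_tensor);
      // if this address changes between iterations, any pointer baked into a
      // hand-built graph node would go stale
      std::cout << "iter " << i << " output data_ptr: " << y.data_ptr() << std::endl;
  }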

Maybe this defeats the purpose of what you are trying to achieve, but I wonder if your approach could work with nodes that were produced by separate graph captures, as sketched below.
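
A rough sketch of what I mean, reusing graph, torchStream, and the kernels from your snippet, and assuming the patched forward_graph.graph_ is still accessible:

  // capture the input-generation kernel into its own graph via stream capture
  cudaGraph_t pre_graph;
  cudaStreamBeginCapture(torchStream.stream(), cudaStreamCaptureModeGlobal);
  generateData<<<1, num_av, 0, torchStream.stream()>>>(num_av, 1, (float4*)input_arr);
  cudaStreamEndCapture(torchStream.stream(), &pre_graph);

  // embed both captured graphs as child nodes of the hand-built parent graph
  cudaGraphNode_t pre_node, fwd_node;
  cudaGraphAddChildGraphNode(&pre_node, graph, nullptr, 0, pre_graph);
  cudaGraphAddChildGraphNode(&fwd_node, graph, &pre_node, 1, forward_graph.graph_);
  cudaGraphDestroy(pre_graph);  // safe: the node holds a clone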

I have changed the code to warm start more and to use graph capture and the torch API in a more traditional manner. However, printOutput is still not working. I believe Tensor::data_ptr() might be a CPU-address-space pointer. Is it possible to get the GPU-address-space pointer and check whether it is consistent?

Currently printOutput still prints all zeros. This might be because I am printing from the original allocation backing the ::from_blob tensor. If I instead print from .data_ptr(), the graph replay fails with an illegal address, which makes me think it is one of two potential problems:

  1. The allocator is giving me different data pointers.
  2. The ::data_ptr() function does not return the address I am looking for.

Note: when I print ::data_ptr() during warm start, that address still changes every iteration. Would I have to change the internal allocator? How else would I integrate torch inference into a CUDA application?
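
One pattern I am considering (a sketch only; I have not verified it on v2.0.1, and it reuses options from above) is to keep torch-owned static tensors whose handles are never rebound, and to copy_ at the boundary inside the capture, so the result always lands at a fixed address:

  // static buffers allocated once by torch; their data_ptr() should stay fixed
  torch::Tensor static_input  = torch::zeros({num_av, input_size}, options);
  torch::Tensor static_output = torch::zeros({num_av, output_size}, options);

  forward_graph.capture_begin();
  torch::Tensor y = model.forward(static_input);
  static_output.copy_(y);   // write the result into the fixed buffer
  forward_graph.capture_end();

  // per step: static_input.copy_(new_data); forward_graph.replay(); read static_output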

Updated Code:

#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

#include <iomanip>
#include <iostream>

#include "helper_cuda.h"

void printTensorSize(torch::Tensor& tensor) {
    std::cout << "Size: [";
    for (size_t i = 0; i < tensor.sizes().size(); ++i) {
        std::cout << tensor.sizes()[i];
        if (i < tensor.sizes().size() - 1) {
            std::cout << ", ";
        }
    }
    std::cout << "]" << std::endl;
}

void printTensor(torch::Tensor& tensor) {
    if (tensor.dim() != 2) {
        std::cout << "Error: Tensor is not 2-dimensional." << std::endl;
        return;
    }

    for (int64_t i = 0; i < tensor.size(0); ++i) {
        for (int64_t j = 0; j < tensor.size(1); ++j) {
            std::cout << std::setw(10) << std::setprecision(4)
                      << tensor[i][j].item<float>() << " ";
        }
        std::cout << std::endl;
    }
    std::cout << "Address of the data: " << tensor.data_ptr() << std::endl;
}

struct Net : torch::nn::Module {
    torch::nn::Linear linear1, linear2, linear3;
    Net(int64_t input, int64_t hidden1, int64_t hidden2, int64_t output, int num_av, float* output_arr, torch::TensorOptions& options)
        : linear1(
              register_module("linear1", torch::nn::Linear(input, hidden1))),
          linear2(
              register_module("linear2", torch::nn::Linear(hidden1, hidden2))),
          linear3(
              register_module("linear3", torch::nn::Linear(hidden2, output))) {}

    // TODO: make in place version
    torch::Tensor forward(torch::Tensor x) {
        x = linear1->forward(x);
        torch::relu_(x);

        x = linear2->forward(x);
        torch::relu_(x);

        x = linear3->forward(x);

        auto batch_size = x.size(0);
        x = torch::arange(0, 4, torch::dtype(torch::kFloat32)).repeat({batch_size, 1});
        return x;
    }
};

__global__ void copyOutput(int num_avs,float4* input, float4* output){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    // auto f4_input = (float4*) input;
    printf("copying: %p to 0x%p\n", input, output);
    if(tid < num_avs){
        output[tid].x=input[tid].x;
        output[tid].y=input[tid].y;
        output[tid].z=input[tid].z;
        output[tid].w=input[tid].w;
    }
}

__global__ void printOutput(int num_avs,float4 * output){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    if(tid == 0){
        printf("cudaptr: %p\n", output);
    }
    if(tid < num_avs){
        printf("tid: %d, Agent Output: %f %f %f %f\n ", tid,
            output[tid].x,
            output[tid].y,
            output[tid].z,
            output[tid].w);
    }
}

__global__ void generateData(int limit, int set,float* input){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    if(tid == 0){
        printf("generating: %p\n", input);
    }
    if(tid < limit){
        input[tid] = set;
    }
}

int main() {
    torch::Device device(torch::kCUDA);
    int num_av = 10;
    int input_size = 40; // must be divisible by 4
    int output_size = 4; // must be divisible by 4
    float* input_arr, *output_arr;

    // allocate input and output arrays for underlying data in Tensors
    cudaError_t cudaStatus = cudaMallocManaged(&input_arr, input_size*num_av*sizeof(float));
    cudaMemset(input_arr, 0, input_size*num_av*sizeof(float));
    cudaStatus = cudaMallocManaged(&output_arr, output_size*num_av*sizeof(float));
    cudaMemset(output_arr, 0, output_size*num_av*sizeof(float));

    // allocate input and output tensors based on constant address internal arrays
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(device);
    torch::Tensor input_tensor = torch::from_blob(input_arr, {num_av, input_size}, options);
    torch::Tensor output_tensor = torch::from_blob(output_arr, {num_av, output_size}, options);

    // create network
    Net model(input_size, 120, 20, output_size, num_av, output_arr, options);
    model.to(device);

    // warm start
    at::cuda::CUDAStream torchStream = at::cuda::getStreamFromPool(true);
    at::cuda::setCurrentCUDAStream(torchStream);
    // torch::Tensor _;
    for (int i = 0; i < 10; i++) {
        generateData<<<dim3(1,1,1), dim3(input_size,1,1), 0, torchStream>>>(input_size, 1, input_arr);
        output_tensor = model.forward(input_tensor);
        std::cout << "WARM: " << output_tensor.data_ptr() << std::endl;
        printOutput<<<dim3(1,1,1), dim3(num_av,1,1), 0, torchStream>>>(num_av, (float4*)output_arr);
    }
    std::cout << "Warm Start Finished" <<std::endl;

    // capture forward
    at::cuda::CUDAGraph forward_graph;
    forward_graph.capture_begin();
    generateData<<<dim3(1,1,1), dim3(input_size,1,1), 0, torchStream>>>(input_size, 1, input_arr);
    output_tensor = model.forward(input_tensor);
    printOutput<<<dim3(1,1,1), dim3(num_av,1,1), 0, torchStream>>>(num_av, (float4*)output_arr);
    forward_graph.capture_end();
    cudaDeviceSynchronize();

    std::cout << "Graph Capture Finished" <<std::endl;

    forward_graph.replay();
    cudaDeviceSynchronize();
    forward_graph.replay();
    cudaDeviceSynchronize();
    return 0;
}

Update: here is a stream-capture example that works. I might try to get the graph-API example working as well. I preallocated everything and made sure each tensor was on the GPU. I forgot that forward is a host function and that I cannot explicitly copy with GPU pointers, so I had forward invoke a kernel, and that seemed to work.

#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

#include <iomanip>
#include <iostream>

#include "helper_cuda.h"

void printTensorSize(torch::Tensor& tensor) {
    std::cout << "Size: [";
    for (size_t i = 0; i < tensor.sizes().size(); ++i) {
        std::cout << tensor.sizes()[i];
        if (i < tensor.sizes().size() - 1) {
            std::cout << ", ";
        }
    }
    std::cout << "]" << std::endl;
}

void printTensor(torch::Tensor& tensor) {
    if (tensor.dim() != 2) {
        std::cout << "Error: Tensor is not 2-dimensional." << std::endl;
        return;
    }

    for (int64_t i = 0; i < tensor.size(0); ++i) {
        for (int64_t j = 0; j < tensor.size(1); ++j) {
            std::cout << std::setw(10) << std::setprecision(4)
                      << tensor[i][j].item<float>() << " ";
        }
        std::cout << std::endl;
    }
    std::cout << "Address of the data: " << tensor.data_ptr() << std::endl;
}

__global__ void copyOutput(int num_avs,float4* input, float4* output){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    // auto f4_input = (float4*) input;
    printf("copying: %p to 0x%p\n", input, output);
    if(tid < num_avs){
        output[tid].x=input[tid].x;
        output[tid].y=input[tid].y;
        output[tid].z=input[tid].z;
        output[tid].w=input[tid].w;
    }
}

struct Net : torch::nn::Module {
    torch::nn::Linear linear1, linear2, linear3;
    torch::Tensor preallocated_output, preallocated_1, preallocated_2, preallocated_3, range;
    Net(int64_t input, int64_t hidden1, int64_t hidden2, int64_t output, int num_av, float* output_arr, torch::TensorOptions& options)
        : linear1(
              register_module("linear1", torch::nn::Linear(input, hidden1))),
          linear2(
              register_module("linear2", torch::nn::Linear(hidden1, hidden2))),
          linear3(
              register_module("linear3", torch::nn::Linear(hidden2, output))) {
                preallocated_output = torch::from_blob(output_arr, {num_av, output}, options);
                preallocated_1 = torch::zeros({num_av, hidden1}, options); 
                preallocated_2 = torch::zeros({num_av, hidden2}, options); 
                preallocated_3 = torch::zeros({num_av, output}, options); 
                range = torch::arange(0,output, options).repeat({10,1});
              }

    // TODO: make in place version
    torch::Tensor forward(torch::Tensor x) {
        preallocated_1 = linear1->forward(x);
        torch::relu_(preallocated_1);

        preallocated_2 = linear2->forward(preallocated_1);
        torch::relu_(preallocated_2);

        preallocated_3 = linear3->forward(preallocated_2);

        auto batch_size = preallocated_3.size(0);
        auto options = torch::TensorOptions().dtype(torch::kFloat32);
        copyOutput<<<dim3(1,1,1),dim3(10,1,1)>>>(10, (float4*)(range.data_ptr()), (float4*)(preallocated_output.data_ptr())); 
        return preallocated_3;
    }
};



__global__ void printOutput(int num_avs,float4 * output){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    if(tid == 0){
        printf("cudaptr: %p\n", output);
    }
    if(tid < num_avs){
        printf("tid: %d, Agent Output: %f %f %f %f\n ", tid,
            output[tid].x,
            output[tid].y,
            output[tid].z,
            output[tid].w);
    }
}

__global__ void generateData(int limit, int set,float* input){
    const int tid = threadIdx.x + blockIdx.x * (blockDim.x);
    if(tid == 0){
        printf("generating: %p\n", input);
    }
    if(tid < limit){
        input[tid] = set;
    }
}

int main() {
    torch::Device device(torch::kCUDA);
    int num_av = 10;
    int input_size = 40; // must be divisible by 4
    int output_size = 4; // must be divisible by 4
    float* input_arr, *output_arr;

    // allocate input and output arrays for underlying data in Tensors
    cudaError_t cudaStatus = cudaMallocManaged(&input_arr, input_size*num_av*sizeof(float));
    cudaMemset(input_arr, 0, input_size*num_av*sizeof(float));
    cudaStatus = cudaMallocManaged(&output_arr, output_size*num_av*sizeof(float));
    cudaMemset(output_arr, 0, output_size*num_av*sizeof(float));

    // allocate input and output tensors based on constant address internal arrays
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(device);
    torch::Tensor input_tensor = torch::from_blob(input_arr, {num_av, input_size}, options);

    // create network
    Net model(input_size, 120, 20, output_size, num_av, output_arr, options);
    model.to(device);

    // warm start
    at::cuda::CUDAStream torchStream = at::cuda::getStreamFromPool(true);
    at::cuda::setCurrentCUDAStream(torchStream);
    torch::Tensor _;
    for (int i = 0; i < 10; i++) {
        generateData<<<dim3(1,1,1), dim3(input_size,1,1), 0, torchStream>>>(input_size, 2, input_arr);
        cudaDeviceSynchronize();
        _ = model.forward(input_tensor);
        cudaDeviceSynchronize();
        // std::cout << "WARM: " << model.preallocated_output.data_ptr() << std::endl;
        printTensor(model.preallocated_3);
        printTensor(model.preallocated_output);
        printOutput<<<dim3(1,1,1), dim3(num_av,1,1), 0, torchStream>>>(num_av, (float4*)output_arr);
        cudaDeviceSynchronize();
    }
    std::cout << "Warm Start Finished" <<std::endl;

    // capture forward
    at::cuda::CUDAGraph forward_graph;
    forward_graph.capture_begin();
    generateData<<<dim3(1,1,1), dim3(input_size,1,1), 0, torchStream>>>(input_size, 3, input_arr);
    _ = model.forward(input_tensor);
    printOutput<<<dim3(1,1,1), dim3(num_av,1,1), 0, torchStream>>>(num_av, (float4*)output_arr);
    forward_graph.capture_end();
    cudaDeviceSynchronize();

    std::cout << "Graph Capture Finished" <<std::endl;

    forward_graph.replay();
    cudaDeviceSynchronize();
    forward_graph.replay();
    cudaDeviceSynchronize();
    return 0;
}