NVIDIA N-body: executing a CUDA kernel with PyTorch

Hello, I recently came across the Fast N-Body Simulation with CUDA example on the official NVIDIA developer website.

I have an NVIDIA RTX 2060 running under Debian Bookworm with driver version 535.183.01. The nvidia-cuda-toolkit package 11.8.89~11.8.0-5~deb12u1 is installed, providing CUDA compilation tools release 11.8, V11.8.89. The CUDA compiler driver reports:

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022

I installed PyTorch version 2.5.1 with the Anaconda installer. The output of conda list pytorch:
pytorch          2.5.1    py3.10_cuda11.8_cudnn9.1.0_0    pytorch
pytorch-cuda     11.8     h7e8668a_6                      pytorch
pytorch-mutex    1.0      cuda                            pytorch

PyTorch picks up the correct CUDA toolkit path (set in .bashrc):
print(torch.utils.cpp_extension.CUDA_HOME)
/usr/lib/nvidia-cuda-toolkit

torch.cuda.is_available()
Out[19]: True
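
For reference, a quick way to double-check the toolchain from Python (just a sketch of sanity checks; the device-capability call only confirms the compute_75 architecture that shows up in the build log further below):

import torch
from torch.utils import cpp_extension

# CUDA version PyTorch itself was built against (reports 11.8 for this install)
print(torch.version.cuda)

# Toolkit directory that load_inline/ninja will take nvcc and headers from
print(cpp_extension.CUDA_HOME)

# Compute capability of the visible GPU, e.g. (7, 5) for an RTX 2060
print(torch.cuda.get_device_capability(0))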

I am struggling to compile even the simplest CUDA kernel through PyTorch's built-in functionality; ninja always raises an error during the compilation step. I would like to keep a purely PyTorch implementation, including the GPU interaction, rather than fall back on PyCUDA, which works but which I want to replace completely.

import torch
from torch.utils.cpp_extension import load_inline


# CUDA kernel
cuda_kernel = """
extern "C" __global__ void calculate_forces(float4 *globalX, float3 *globalA, int N, int p) {
    extern __shared__ float4 shPosition[];
    float4 myPosition;
    float3 acc = {0.0f, 0.0f, 0.0f};
    int gtid = blockIdx.x * blockDim.x + threadIdx.x;

    if (gtid >= N) return;

    myPosition = globalX[gtid];

    for (int i = 0, tile = 0; i < N; i += p, tile++) {
        int idx = tile * blockDim.x + threadIdx.x;

        if (idx < N) {
            shPosition[threadIdx.x] = globalX[idx];
        }
        __syncthreads();

        for (int j = 0; j < p && (tile * blockDim.x + j) < N; j++) {
            float4 otherPosition = shPosition[j];
            float3 r;
            r.x = otherPosition.x - myPosition.x;
            r.y = otherPosition.y - myPosition.y;
            r.z = otherPosition.z - myPosition.z;

            float distSqr = r.x * r.x + r.y * r.y + r.z * r.z + 1e-10f;
            float invDist = rsqrtf(distSqr);
            float invDistCube = invDist * invDist * invDist;

            float s = otherPosition.w * invDistCube;
            acc.x += s * r.x;
            acc.y += s * r.y;
            acc.z += s * r.z;
        }
        __syncthreads();
    }

    globalA[gtid] = acc; // Output only the accelerations as a float3
}
"""

# Compile the CUDA kernel with PyTorch's `load_inline`
module = load_inline(name="cuda_forces", cpp_sources="", 
                     cuda_sources=cuda_kernel, functions=["calculate_forces"])

# Parameters
N = 1000  # Number of particles
p = 128   # Threads per block
block_size = 128  # Threads per block
grid_size = (N + block_size - 1) // block_size  # Number of blocks

# Generate random positions for particles (x, y, z, mass)
# Use PyTorch tensors instead of NumPy arrays
positions = torch.rand((N, 4), device='cuda', dtype=torch.float32)  # (x, y, z, mass)
accelerations = torch.zeros_like(positions, device='cuda', dtype=torch.float32)  # Initialize accelerations

# Convert positions to float4 format for CUDA
positions_float4 = positions.view(-1).contiguous()
accelerations_float4 = accelerations.view(-1).contiguous()

# Get raw pointers to PyTorch tensors
positions_ptr = positions_float4.data_ptr()
accelerations_ptr = accelerations_float4.data_ptr()

# Shared memory size (float4 is 4 floats = 16 bytes)
shared_memory_size = block_size * 16  # Shared memory per block

# Execute the kernel
module.calculate_forces(
    torch.cuda.IntTensor([positions_ptr]),
    torch.cuda.IntTensor([accelerations_ptr]),
    grid=(grid_size, 1, 1),
    block=(block_size, 1, 1),
    shared_memory=shared_memory_size
)

# Accelerations are now computed in the accelerations tensor
print("Computed accelerations:\n", accelerations)

The error log is:

%runfile /home/student/nvidia/nvidia.py --wdir
/home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
---------------------------------------------------------------------------
CalledProcessError                        Traceback (most recent call last)
File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:2104, in _run_ninja_build(build_directory, verbose, error_prefix)
   2103     stdout_fileno = 1
-> 2104     subprocess.run(
   2105         command,
   2106         stdout=stdout_fileno if verbose else subprocess.PIPE,
   2107         stderr=subprocess.STDOUT,
   2108         cwd=build_directory,
   2109         check=True,
   2110         env=env)
   2111 except subprocess.CalledProcessError as e:
   2112     # Python 2 and 3 compatible way of getting the error object.

File ~/anaconda3/envs/pytorch/lib/python3.10/subprocess.py:526, in run(input, capture_output, timeout, check, *popenargs, **kwargs)
    525     if check and retcode:
--> 526         raise CalledProcessError(retcode, process.args,
    527                                  output=stdout, stderr=stderr)
    528 return CompletedProcess(process.args, retcode, stdout, stderr)

CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/spyder_kernels/customize/utils.py:209, in exec_encapsulate_locals(code_ast, globals, locals, exec_fun, filename)
    207     if filename is None:
    208         filename = "<stdin>"
--> 209     exec_fun(compile(code_ast, filename, "exec"), globals, None)
    210 finally:
    211     if use_locals_hack:
    212         # Cleanup code

File /home/student/nvidia/nvidia.py:57
     14 cuda_kernel = """
     15 extern "C" __global__ void calculate_forces(float4 *globalX, float3 *globalA, int N, int p) {
     16     extern __shared__ float4 shPosition[];
   (...)
     53 }
     54 """
     56 # Compile the CUDA kernel with PyTorch's `load_inline`
---> 57 module = load_inline(name="cuda_forces", cpp_sources="", 
     58                      cuda_sources=cuda_kernel, functions=["calculate_forces"])
     60 # Parameters
     61 N = 1000  # Number of particles

File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1646, in load_inline(name, cpp_sources, cuda_sources, functions, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths, build_directory, verbose, with_cuda, is_python_module, with_pytorch_error_handling, keep_intermediates, use_pch)
   1642     _maybe_write(cuda_source_path, "\n".join(cuda_sources))
   1644     sources.append(cuda_source_path)
-> 1646 return _jit_compile(
   1647     name,
   1648     sources,
   1649     extra_cflags,
   1650     extra_cuda_cflags,
   1651     extra_ldflags,
   1652     extra_include_paths,
   1653     build_directory,
   1654     verbose,
   1655     with_cuda,
   1656     is_python_module,
   1657     is_standalone=False,
   1658     keep_intermediates=keep_intermediates)

File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1721, in _jit_compile(name, sources, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths, build_directory, verbose, with_cuda, is_python_module, is_standalone, keep_intermediates)
   1717                 hipified_sources.add(hipify_result[s_abs].hipified_path if s_abs in hipify_result else s_abs)
   1719             sources = list(hipified_sources)
-> 1721         _write_ninja_file_and_build_library(
   1722             name=name,
   1723             sources=sources,
   1724             extra_cflags=extra_cflags or [],
   1725             extra_cuda_cflags=extra_cuda_cflags or [],
   1726             extra_ldflags=extra_ldflags or [],
   1727             extra_include_paths=extra_include_paths or [],
   1728             build_directory=build_directory,
   1729             verbose=verbose,
   1730             with_cuda=with_cuda,
   1731             is_standalone=is_standalone)
   1732 elif verbose:
   1733     print('No modifications detected for re-loaded extension '
   1734           f'module {name}, skipping build step...', file=sys.stderr)

File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1833, in _write_ninja_file_and_build_library(name, sources, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths, build_directory, verbose, with_cuda, is_standalone)
   1831 if verbose:
   1832     print(f'Building extension module {name}...', file=sys.stderr)
-> 1833 _run_ninja_build(
   1834     build_directory,
   1835     verbose,
   1836     error_prefix=f"Error building extension '{name}'")

File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:2120, in _run_ninja_build(build_directory, verbose, error_prefix)
   2118 if hasattr(error, 'output') and error.output:  # type: ignore[union-attr]
   2119     message += f": {error.output.decode(*SUBPROCESS_DECODE_ARGS)}"  # type: ignore[union-attr]
-> 2120 raise RuntimeError(message) from e

RuntimeError: Error building extension 'cuda_forces': [1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=cuda_forces -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/TH -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/THC -isystem /usr/lib/nvidia-cuda-toolkit/include -isystem /home/student/anaconda3/envs/pytorch/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/student/.cache/torch_extensions/py310_cu118/cuda_forces/main.cpp -o main.o 
FAILED: main.o 
c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=cuda_forces -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/TH -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/THC -isystem /usr/lib/nvidia-cuda-toolkit/include -isystem /home/student/anaconda3/envs/pytorch/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/student/.cache/torch_extensions/py310_cu118/cuda_forces/main.cpp -o main.o 
/home/student/.cache/torch_extensions/py310_cu118/cuda_forces/main.cpp: In function ‘void pybind11_init_cuda_forces(pybind11::module_&)’:
/home/student/.cache/torch_extensions/py310_cu118/cuda_forces/main.cpp:4:55: error: ‘calculate_forces’ was not declared in this scope
    4 | m.def("calculate_forces", torch::wrap_pybind_function(calculate_forces), "calculate_forces");
      |                                                       ^~~~~~~~~~~~~~~~
[2/3] /usr/lib/nvidia-cuda-toolkit/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -DTORCH_EXTENSION_NAME=cuda_forces -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/TH -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/THC -isystem /usr/lib/nvidia-cuda-toolkit/include -isystem /home/student/anaconda3/envs/pytorch/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_75,code=sm_75 --compiler-options '-fPIC' -std=c++17 -c /home/student/.cache/torch_extensions/py310_cu118/cuda_forces/cuda.cu -o cuda.cuda.o 
ninja: build stopped: subcommand failed.

Below is the script that uses PyCUDA and successfully executes the CUDA kernel from the NVIDIA website.

import pycuda.autoinit
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
import torch
import numpy as np

# CUDA kernel
cuda_kernel = """
extern "C" __global__ void calculate_forces(float4 *globalX, float3 *globalA, int N, int p) {
    extern __shared__ float4 shPosition[];
    float4 myPosition;
    float3 acc = {0.0f, 0.0f, 0.0f};
    int gtid = blockIdx.x * blockDim.x + threadIdx.x;

    if (gtid >= N) return;

    myPosition = globalX[gtid];

    for (int i = 0, tile = 0; i < N; i += p, tile++) {
        int idx = tile * blockDim.x + threadIdx.x;

        if (idx < N) {
            shPosition[threadIdx.x] = globalX[idx];
        }
        __syncthreads();

        for (int j = 0; j < p && (tile * blockDim.x + j) < N; j++) {
            float4 otherPosition = shPosition[j];
            float3 r;
            r.x = otherPosition.x - myPosition.x;
            r.y = otherPosition.y - myPosition.y;
            r.z = otherPosition.z - myPosition.z;

            float distSqr = r.x * r.x + r.y * r.y + r.z * r.z + 1e-10f;
            float invDist = rsqrtf(distSqr);
            float invDistCube = invDist * invDist * invDist;

            float s = otherPosition.w * invDistCube;
            acc.x += s * r.x;
            acc.y += s * r.y;
            acc.z += s * r.z;
        }
        __syncthreads();
    }

    globalA[gtid] = acc; // Output only the accelerations as a float3
}
"""

# Compile the CUDA kernel
mod = SourceModule(cuda_kernel)
calculate_forces = mod.get_function("calculate_forces")

def calculate_forces_with_pycuda(input_tensor: torch.Tensor) -> torch.Tensor:
    """
    Compute forces based on positions and masses using a PyCUDA kernel.

    Args:
        input_tensor (torch.Tensor): Input tensor of shape (N, 4) containing positions (x, y, z) and masses.

    Returns:
        torch.Tensor: Output tensor of shape (N, 3) containing the computed forces.
    """
    # Validate input
    assert input_tensor.shape[1] == 4, "Input tensor must have shape (N, 4)."
    assert input_tensor.is_cuda, "Input tensor must be on CUDA device."
    assert input_tensor.dtype == torch.float32, "Input tensor must have dtype torch.float32."

    # Number of particles
    N = input_tensor.size(0)

    # Create output tensor for accelerations
    output_tensor = torch.empty((N, 3), device="cuda", dtype=torch.float32)

    # Allocate device memory for input and output
    dev_positions = cuda.mem_alloc(input_tensor.nbytes)
    dev_accelerations = cuda.mem_alloc(output_tensor.nbytes)

    # Transfer data to device
    cuda.memcpy_htod(dev_positions, input_tensor.contiguous().cpu().numpy())

    # Define block and grid sizes
    threads_per_block = 256
    blocks_per_grid = (N + threads_per_block - 1) // threads_per_block

    # Run the kernel
    calculate_forces(
        dev_positions,
        dev_accelerations,
        np.int32(N),
        np.int32(threads_per_block),
        block=(threads_per_block, 1, 1),
        grid=(blocks_per_grid, 1, 1),
        shared=threads_per_block * 16  # Shared memory size (16 bytes per float4)
    )

    cuda.Context.synchronize()

    # Copy the results back to host and convert to torch tensor
    forces_np = np.empty((N, 3), dtype=np.float32)
    cuda.memcpy_dtoh(forces_np, dev_accelerations)

    # Convert numpy array back to torch tensor
    forces_tensor = torch.tensor(forces_np, device="cuda", dtype=torch.float32)

    return forces_tensor


# Example usage
if __name__ == "__main__":
    # Number of particles
    N = 1024

    # Create an input tensor with positions and masses
    input_tensor = torch.zeros((N, 4), device="cuda", dtype=torch.float32)
    input_tensor[:, :3] = torch.rand((N, 3), device="cuda")  # Random positions
    input_tensor[:, 3] = torch.rand((N,), device="cuda") + 0.1  # Random masses

    # Calculate forces
    forces = calculate_forces_with_pycuda(input_tensor)

    # Print results
    print("Forces:")
    print(forces)

Could you suggest how to get past the compilation error? I am able to use torch.cuda and execute computations on the GPU; it is ‘just’ a matter of how to correctly compile and launch a CUDA kernel with PyTorch's built-in functionality.

To build a custom CUDA extension you need a CUDA toolkit installed locally, which is the case here, and the compilation then fails with:

error: ‘calculate_forces’ was not declared in this scope

So it seems the function is not declared where the auto-generated binding expects it. With functions=["calculate_forces"], load_inline writes a main.cpp containing m.def("calculate_forces", torch::wrap_pybind_function(calculate_forces), ...), which refers to a host-side C++ function of that name; the only calculate_forces in the sources, however, is a __global__ kernel, and a kernel cannot be bound to Python directly.
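
Here is a minimal sketch of the shape load_inline expects, assuming you want to keep the kernel logic unchanged: the __global__ kernel stays in cuda_sources under a different name (calculate_forces_kernel is my choice, not anything from your code), and the name passed in functions=["calculate_forces"] belongs to an ordinary host C++ function that takes and returns torch::Tensor, declared in cpp_sources and defined, together with the <<<...>>> launch, in cuda_sources. The block size of 128 and the (N, 4) position layout are taken from your script; the rest is an illustrative guess, not a verified drop-in.

import torch
from torch.utils.cpp_extension import load_inline

cuda_src = r"""
// Device kernel: same body as in the question, renamed so the Python-visible
// name can be used for the host wrapper below.
__global__ void calculate_forces_kernel(const float4 *globalX, float3 *globalA,
                                        int N, int p) {
    extern __shared__ float4 shPosition[];
    int gtid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gtid >= N) return;

    float4 myPosition = globalX[gtid];
    float3 acc = {0.0f, 0.0f, 0.0f};

    for (int i = 0, tile = 0; i < N; i += p, tile++) {
        int idx = tile * blockDim.x + threadIdx.x;
        if (idx < N) shPosition[threadIdx.x] = globalX[idx];
        __syncthreads();

        for (int j = 0; j < p && (tile * blockDim.x + j) < N; j++) {
            float4 other = shPosition[j];
            float3 r = {other.x - myPosition.x,
                        other.y - myPosition.y,
                        other.z - myPosition.z};
            float distSqr = r.x * r.x + r.y * r.y + r.z * r.z + 1e-10f;
            float invDist = rsqrtf(distSqr);
            float s = other.w * invDist * invDist * invDist;
            acc.x += s * r.x; acc.y += s * r.y; acc.z += s * r.z;
        }
        __syncthreads();
    }
    globalA[gtid] = acc;
}

// Host wrapper: this is the symbol that functions=["calculate_forces"] binds.
// It takes and returns torch::Tensor, so no raw pointers or launch parameters
// cross the Python boundary.
torch::Tensor calculate_forces(torch::Tensor positions) {
    TORCH_CHECK(positions.is_cuda(), "positions must be a CUDA tensor");
    TORCH_CHECK(positions.dtype() == torch::kFloat32, "positions must be float32");
    TORCH_CHECK(positions.dim() == 2 && positions.size(1) == 4,
                "positions must have shape (N, 4): x, y, z, mass");
    positions = positions.contiguous();

    const int N = positions.size(0);
    auto acc = torch::zeros({N, 3}, positions.options());

    const int block = 128;                       // threads per block (= p)
    const int grid  = (N + block - 1) / block;   // number of blocks
    const size_t shmem = block * sizeof(float4); // shared tile of positions

    calculate_forces_kernel<<<grid, block, shmem>>>(
        reinterpret_cast<const float4 *>(positions.data_ptr<float>()),
        reinterpret_cast<float3 *>(acc.data_ptr<float>()),
        N, block);
    return acc;
}
"""

# Declaration only; load_inline generates the pybind11 m.def(...) from it.
cpp_src = "torch::Tensor calculate_forces(torch::Tensor positions);"

module = load_inline(
    name="cuda_forces",
    cpp_sources=cpp_src,
    cuda_sources=cuda_src,
    functions=["calculate_forces"],
    verbose=True,
)

positions = torch.rand((1000, 4), device="cuda", dtype=torch.float32)
accelerations = module.calculate_forces(positions)  # shape (1000, 3)
print("Computed accelerations:\n", accelerations)

With this structure the generated main.cpp finds a declaration of calculate_forces and the linker resolves it against the compiled .cu object; grid, block, and shared-memory sizes live inside the wrapper, so nothing pointer- or launch-related is passed from Python. A production version would presumably also launch on PyTorch's current CUDA stream and check cudaGetLastError() after the launch, which I left out to keep the sketch short.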