Hello, I recently came across the Fast N-Body Simulation with CUDA example on the official NVIDIA developer website. I have an NVIDIA RTX 2060 running under Debian Bookworm with driver version 535.183.01, and I have nvidia-cuda-toolkit 11.8.89~11.8.0-5~deb12u1 installed (CUDA compilation tools release 11.8, V11.8.89). The CUDA compiler driver reports:
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
I installed PyTorch 2.5.1 with the Anaconda installer. The result of conda list pytorch is:
pytorch                   2.5.1           py3.10_cuda11.8_cudnn9.1.0_0    pytorch
pytorch-cuda              11.8            h7e8668a_6                      pytorch
pytorch-mutex             1.0             cuda                            pytorch
PyTorch picks up the correct CUDA toolkit path set in my .bashrc:
print(torch.utils.cpp_extension.CUDA_HOME)
/usr/lib/nvidia-cuda-toolkit
torch.cuda.is_available()
Out[19]: True
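For completeness, a quick sanity check along these lines (device index 0 is assumed) also confirms that the CUDA version PyTorch was built against, the toolkit root it hands to nvcc, and the card's compute capability all line up; an RTX 2060 should report (7, 5), i.e. sm_75, which matches the compute_75 flags that appear in the build log further below.

import torch
from torch.utils import cpp_extension

print(torch.version.cuda)                   # CUDA version PyTorch was built against (11.8 expected here)
print(cpp_extension.CUDA_HOME)              # toolkit root used when building extensions
print(torch.cuda.get_device_name(0))        # visible GPU
print(torch.cuda.get_device_capability(0))  # compute capability; (7, 5) for an RTX 2060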
I am struggling to compile even the simplest CUDA kernel using PyTorch's built-in functionality; ninja always raises an error during the compilation step. I would like to keep the implementation purely in PyTorch, including the GPU interaction, rather than falling back to pycuda, which works but which I would like to replace completely.
import torch
from torch.utils.cpp_extension import load_inline
# CUDA kernel
cuda_kernel = """
extern "C" __global__ void calculate_forces(float4 *globalX, float3 *globalA, int N, int p) {
extern __shared__ float4 shPosition[];
float4 myPosition;
float3 acc = {0.0f, 0.0f, 0.0f};
int gtid = blockIdx.x * blockDim.x + threadIdx.x;
if (gtid >= N) return;
myPosition = globalX[gtid];
for (int i = 0, tile = 0; i < N; i += p, tile++) {
int idx = tile * blockDim.x + threadIdx.x;
if (idx < N) {
shPosition[threadIdx.x] = globalX[idx];
}
__syncthreads();
for (int j = 0; j < p && (tile * blockDim.x + j) < N; j++) {
float4 otherPosition = shPosition[j];
float3 r;
r.x = otherPosition.x - myPosition.x;
r.y = otherPosition.y - myPosition.y;
r.z = otherPosition.z - myPosition.z;
float distSqr = r.x * r.x + r.y * r.y + r.z * r.z + 1e-10f;
float invDist = rsqrtf(distSqr);
float invDistCube = invDist * invDist * invDist;
float s = otherPosition.w * invDistCube;
acc.x += s * r.x;
acc.y += s * r.y;
acc.z += s * r.z;
}
__syncthreads();
}
globalA[gtid] = acc; // Output only the accelerations as a float3
}
"""
# Compile the CUDA kernel with PyTorch's `load_inline`
module = load_inline(name="cuda_forces", cpp_sources="",
                     cuda_sources=cuda_kernel, functions=["calculate_forces"])
# Parameters
N = 1000 # Number of particles
p = 128 # Threads per block
block_size = 128 # Threads per block
grid_size = (N + block_size - 1) // block_size # Number of blocks
# Generate random positions for particles (x, y, z, mass)
# Use PyTorch tensors instead of NumPy arrays
positions = torch.rand((N, 4), device='cuda', dtype=torch.float32) # (x, y, z, mass)
accelerations = torch.zeros_like(positions, device='cuda', dtype=torch.float32) # Initialize accelerations
# Convert positions to float4 format for CUDA
positions_float4 = positions.view(-1).contiguous()
accelerations_float4 = accelerations.view(-1).contiguous()
# Get raw pointers to PyTorch tensors
positions_ptr = positions_float4.data_ptr()
accelerations_ptr = accelerations_float4.data_ptr()
# Shared memory size (float4 is 4 floats = 16 bytes)
shared_memory_size = block_size * 16 # Shared memory per block
# Execute the kernel
module.calculate_forces(
    torch.cuda.IntTensor([positions_ptr]),
    torch.cuda.IntTensor([accelerations_ptr]),
    grid=(grid_size, 1, 1),
    block=(block_size, 1, 1),
    shared_memory=shared_memory_size
)
# Accelerations are now computed in the accelerations tensor
print("Computed accelerations:\n", accelerations)
The error log is:
%runfile /home/student/nvidia/nvidia.py --wdir
/home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
---------------------------------------------------------------------------
CalledProcessError Traceback (most recent call last)
File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:2104, in _run_ninja_build(build_directory, verbose, error_prefix)
2103 stdout_fileno = 1
-> 2104 subprocess.run(
2105 command,
2106 stdout=stdout_fileno if verbose else subprocess.PIPE,
2107 stderr=subprocess.STDOUT,
2108 cwd=build_directory,
2109 check=True,
2110 env=env)
2111 except subprocess.CalledProcessError as e:
2112 # Python 2 and 3 compatible way of getting the error object.
File ~/anaconda3/envs/pytorch/lib/python3.10/subprocess.py:526, in run(input, capture_output, timeout, check, *popenargs, **kwargs)
525 if check and retcode:
--> 526 raise CalledProcessError(retcode, process.args,
527 output=stdout, stderr=stderr)
528 return CompletedProcess(process.args, retcode, stdout, stderr)
CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/spyder_kernels/customize/utils.py:209, in exec_encapsulate_locals(code_ast, globals, locals, exec_fun, filename)
207 if filename is None:
208 filename = "<stdin>"
--> 209 exec_fun(compile(code_ast, filename, "exec"), globals, None)
210 finally:
211 if use_locals_hack:
212 # Cleanup code
File /home/student/nvidia/nvidia.py:57
14 cuda_kernel = """
15 extern "C" __global__ void calculate_forces(float4 *globalX, float3 *globalA, int N, int p) {
16 extern __shared__ float4 shPosition[];
(...)
53 }
54 """
56 # Compile the CUDA kernel with PyTorch's `load_inline`
---> 57 module = load_inline(name="cuda_forces", cpp_sources="",
58 cuda_sources=cuda_kernel, functions=["calculate_forces"])
60 # Parameters
61 N = 1000 # Number of particles
File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1646, in load_inline(name, cpp_sources, cuda_sources, functions, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths, build_directory, verbose, with_cuda, is_python_module, with_pytorch_error_handling, keep_intermediates, use_pch)
1642 _maybe_write(cuda_source_path, "\n".join(cuda_sources))
1644 sources.append(cuda_source_path)
-> 1646 return _jit_compile(
1647 name,
1648 sources,
1649 extra_cflags,
1650 extra_cuda_cflags,
1651 extra_ldflags,
1652 extra_include_paths,
1653 build_directory,
1654 verbose,
1655 with_cuda,
1656 is_python_module,
1657 is_standalone=False,
1658 keep_intermediates=keep_intermediates)
File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1721, in _jit_compile(name, sources, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths, build_directory, verbose, with_cuda, is_python_module, is_standalone, keep_intermediates)
1717 hipified_sources.add(hipify_result[s_abs].hipified_path if s_abs in hipify_result else s_abs)
1719 sources = list(hipified_sources)
-> 1721 _write_ninja_file_and_build_library(
1722 name=name,
1723 sources=sources,
1724 extra_cflags=extra_cflags or [],
1725 extra_cuda_cflags=extra_cuda_cflags or [],
1726 extra_ldflags=extra_ldflags or [],
1727 extra_include_paths=extra_include_paths or [],
1728 build_directory=build_directory,
1729 verbose=verbose,
1730 with_cuda=with_cuda,
1731 is_standalone=is_standalone)
1732 elif verbose:
1733 print('No modifications detected for re-loaded extension '
1734 f'module {name}, skipping build step...', file=sys.stderr)
File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:1833, in _write_ninja_file_and_build_library(name, sources, extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths, build_directory, verbose, with_cuda, is_standalone)
1831 if verbose:
1832 print(f'Building extension module {name}...', file=sys.stderr)
-> 1833 _run_ninja_build(
1834 build_directory,
1835 verbose,
1836 error_prefix=f"Error building extension '{name}'")
File ~/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/utils/cpp_extension.py:2120, in _run_ninja_build(build_directory, verbose, error_prefix)
2118 if hasattr(error, 'output') and error.output: # type: ignore[union-attr]
2119 message += f": {error.output.decode(*SUBPROCESS_DECODE_ARGS)}" # type: ignore[union-attr]
-> 2120 raise RuntimeError(message) from e
RuntimeError: Error building extension 'cuda_forces': [1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=cuda_forces -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/TH -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/THC -isystem /usr/lib/nvidia-cuda-toolkit/include -isystem /home/student/anaconda3/envs/pytorch/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/student/.cache/torch_extensions/py310_cu118/cuda_forces/main.cpp -o main.o
FAILED: main.o
c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=cuda_forces -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/TH -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/THC -isystem /usr/lib/nvidia-cuda-toolkit/include -isystem /home/student/anaconda3/envs/pytorch/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/student/.cache/torch_extensions/py310_cu118/cuda_forces/main.cpp -o main.o
/home/student/.cache/torch_extensions/py310_cu118/cuda_forces/main.cpp: In function ‘void pybind11_init_cuda_forces(pybind11::module_&)’:
/home/student/.cache/torch_extensions/py310_cu118/cuda_forces/main.cpp:4:55: error: ‘calculate_forces’ was not declared in this scope
4 | m.def("calculate_forces", torch::wrap_pybind_function(calculate_forces), "calculate_forces");
| ^~~~~~~~~~~~~~~~
[2/3] /usr/lib/nvidia-cuda-toolkit/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -DTORCH_EXTENSION_NAME=cuda_forces -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/TH -isystem /home/student/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/include/THC -isystem /usr/lib/nvidia-cuda-toolkit/include -isystem /home/student/anaconda3/envs/pytorch/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_75,code=sm_75 --compiler-options '-fPIC' -std=c++17 -c /home/student/.cache/torch_extensions/py310_cu118/cuda_forces/cuda.cu -o cuda.cuda.o
ninja: build stopped: subcommand failed.
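Reading the failing line in the generated main.cpp, my understanding is that each name passed in functions= has to be an ordinary host-side C++ function that the auto-generated binding code can see (declared in cpp_sources and defined in cuda_sources), rather than the __global__ kernel itself. Below is a minimal sketch of that wrapper pattern, adapted from the load_inline documentation example; renaming the kernel to calculate_forces_kernel, the launcher signature, and the torch/extension.h include are my own assumptions, and I have not verified that this builds on my setup:

import torch
from torch.utils.cpp_extension import load_inline

# Declaration that the auto-generated main.cpp can see, so the pybind
# binding for "calculate_forces" resolves to a host function.
cpp_source = "torch::Tensor calculate_forces(torch::Tensor positions, int threads_per_block);"

cuda_source = r"""
#include <torch/extension.h>

__global__ void calculate_forces_kernel(const float4 *globalX, float3 *globalA, int N, int p) {
    extern __shared__ float4 shPosition[];
    float3 acc = {0.0f, 0.0f, 0.0f};
    int gtid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gtid >= N) return;
    float4 myPosition = globalX[gtid];
    for (int i = 0, tile = 0; i < N; i += p, tile++) {
        int idx = tile * blockDim.x + threadIdx.x;
        if (idx < N) shPosition[threadIdx.x] = globalX[idx];
        __syncthreads();
        for (int j = 0; j < p && (tile * blockDim.x + j) < N; j++) {
            float4 other = shPosition[j];
            float3 r = {other.x - myPosition.x, other.y - myPosition.y, other.z - myPosition.z};
            float distSqr = r.x * r.x + r.y * r.y + r.z * r.z + 1e-10f;
            float invDist = rsqrtf(distSqr);
            float s = other.w * invDist * invDist * invDist;
            acc.x += s * r.x;
            acc.y += s * r.y;
            acc.z += s * r.z;
        }
        __syncthreads();
    }
    globalA[gtid] = acc;
}

// Host-side launcher: this is the symbol named in functions=[...].
torch::Tensor calculate_forces(torch::Tensor positions, int threads_per_block) {
    TORCH_CHECK(positions.is_cuda(), "positions must be a CUDA tensor");
    TORCH_CHECK(positions.scalar_type() == torch::kFloat32, "positions must be float32");
    auto pos = positions.contiguous();                  // (N, 4): x, y, z, mass
    const int N = static_cast<int>(pos.size(0));
    auto acc = torch::zeros({N, 3}, pos.options());     // (N, 3) accelerations
    const int block = threads_per_block;
    const int grid = (N + block - 1) / block;
    const size_t shmem = block * sizeof(float4);        // one float4 per thread
    calculate_forces_kernel<<<grid, block, shmem>>>(
        reinterpret_cast<const float4 *>(pos.data_ptr<float>()),
        reinterpret_cast<float3 *>(acc.data_ptr<float>()),
        N, threads_per_block);
    return acc;
}
"""

module = load_inline(name="cuda_forces",
                     cpp_sources=cpp_source,
                     cuda_sources=cuda_source,
                     functions=["calculate_forces"],
                     verbose=True)

The idea behind this sketch is that the bound function takes and returns ordinary torch tensors, so no raw device pointers or grid/block keyword arguments would have to cross the Python boundary.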
Below is the script that uses pycuda and is able to execute the CUDA kernel from the NVIDIA website.
import pycuda.autoinit
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
import torch
import numpy as np
# CUDA kernel
cuda_kernel = """
extern "C" __global__ void calculate_forces(float4 *globalX, float3 *globalA, int N, int p) {
extern __shared__ float4 shPosition[];
float4 myPosition;
float3 acc = {0.0f, 0.0f, 0.0f};
int gtid = blockIdx.x * blockDim.x + threadIdx.x;
if (gtid >= N) return;
myPosition = globalX[gtid];
for (int i = 0, tile = 0; i < N; i += p, tile++) {
int idx = tile * blockDim.x + threadIdx.x;
if (idx < N) {
shPosition[threadIdx.x] = globalX[idx];
}
__syncthreads();
for (int j = 0; j < p && (tile * blockDim.x + j) < N; j++) {
float4 otherPosition = shPosition[j];
float3 r;
r.x = otherPosition.x - myPosition.x;
r.y = otherPosition.y - myPosition.y;
r.z = otherPosition.z - myPosition.z;
float distSqr = r.x * r.x + r.y * r.y + r.z * r.z + 1e-10f;
float invDist = rsqrtf(distSqr);
float invDistCube = invDist * invDist * invDist;
float s = otherPosition.w * invDistCube;
acc.x += s * r.x;
acc.y += s * r.y;
acc.z += s * r.z;
}
__syncthreads();
}
globalA[gtid] = acc; // Output only the accelerations as a float3
}
"""
# Compile the CUDA kernel
mod = SourceModule(cuda_kernel)
calculate_forces = mod.get_function("calculate_forces")
def calculate_forces_with_pycuda(input_tensor: torch.Tensor) -> torch.Tensor:
    """
    Compute forces based on positions and masses using a PyCUDA kernel.

    Args:
        input_tensor (torch.Tensor): Input tensor of shape (N, 4) containing positions (x, y, z) and masses.

    Returns:
        torch.Tensor: Output tensor of shape (N, 3) containing the computed forces.
    """
    # Validate input
    assert input_tensor.shape[1] == 4, "Input tensor must have shape (N, 4)."
    assert input_tensor.is_cuda, "Input tensor must be on CUDA device."
    assert input_tensor.dtype == torch.float32, "Input tensor must have dtype torch.float32."

    # Number of particles
    N = input_tensor.size(0)

    # Create output tensor for accelerations
    output_tensor = torch.empty((N, 3), device="cuda", dtype=torch.float32)

    # Allocate device memory for input and output
    dev_positions = cuda.mem_alloc(input_tensor.nbytes)
    dev_accelerations = cuda.mem_alloc(output_tensor.nbytes)

    # Transfer data to device
    cuda.memcpy_htod(dev_positions, input_tensor.contiguous().cpu().numpy())

    # Define block and grid sizes
    threads_per_block = 256
    blocks_per_grid = (N + threads_per_block - 1) // threads_per_block

    # Run the kernel
    calculate_forces(
        dev_positions,
        dev_accelerations,
        np.int32(N),
        np.int32(threads_per_block),
        block=(threads_per_block, 1, 1),
        grid=(blocks_per_grid, 1, 1),
        shared=threads_per_block * 16  # Shared memory size (16 bytes per float4)
    )
    cuda.Context.synchronize()

    # Copy the results back to host and convert to torch tensor
    forces_np = np.empty((N, 3), dtype=np.float32)
    cuda.memcpy_dtoh(forces_np, dev_accelerations)

    # Convert numpy array back to torch tensor
    forces_tensor = torch.tensor(forces_np, device="cuda", dtype=torch.float32)
    return forces_tensor
# Example usage
if __name__ == "__main__":
    # Number of particles
    N = 1024

    # Create an input tensor with positions and masses
    input_tensor = torch.zeros((N, 4), device="cuda", dtype=torch.float32)
    input_tensor[:, :3] = torch.rand((N, 3), device="cuda")  # Random positions
    input_tensor[:, 3] = torch.rand((N,), device="cuda") + 0.1  # Random masses

    # Calculate forces
    forces = calculate_forces_with_pycuda(input_tensor)

    # Print results
    print("Forces:")
    print(forces)
Could you suggest how to get past the compilation error? I am able to use torch.cuda and execute computations on the GPU; it is 'just' a matter of how to correctly compile and launch a CUDA kernel with PyTorch's built-in functionality.
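For reference, the call pattern I am ultimately hoping for looks roughly like this (hypothetical, and assuming a wrapper along the lines of the sketch above that returns the accelerations as an (N, 3) CUDA tensor):

import torch

N = 1000
positions = torch.rand((N, 4), device="cuda", dtype=torch.float32)  # (x, y, z, mass)

# Hypothetical call into the load_inline module: the kernel works directly on
# the tensor's device memory and no data ever leaves the GPU.
accelerations = module.calculate_forces(positions, 128)
print(accelerations.shape)  # expected: torch.Size([1000, 3])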