To improve the throughput of my NeRF-like model for 3D scientific data, I wrote some custom CUDA kernels for the encoding step (the decoding step is just a small MLP). This worked fine and dandy and improved my encoding throughput:
Native PyTorch: 112 mil points/sec, 8 GB VRAM
CUDA kernels: 190 mil points/sec, 4 GB VRAM
Then I thought about AMP and how it could probably speed both up significantly. After rewriting my C++/CUDA kernels to use e.g. packed_accessor32<scalar_t,dims,torch::RestrictPtrTraits>() with templated functions and kernels, my throughput remained the same for float32 but didn't improve much for float16 operations. PyTorch is faster again! Here are the stats using the same code as above, but run inside a with torch.autocast(device_type='cuda', enabled=True, dtype=torch.float16) block:
Native PyTorch: 236 mil points/sec, 7.16 GB
CUDA kernels: 225 mil points/sec, 2.23 GB
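For context, these numbers come from a measurement loop along these lines (a simplified sketch, with synthetic inputs and a placeholder encode_fn rather than my exact benchmark code):

import time
import torch

def measure_throughput(encode_fn, num_points=2**20, iters=100, use_amp=False):
    # Synthetic coordinates in [0, 1); the real benchmark streams points from my data.
    coords = torch.rand(num_points, 3, device="cuda")
    with torch.autocast(device_type="cuda", enabled=use_amp, dtype=torch.float16):
        for _ in range(10):  # warm-up
            encode_fn(coords)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            encode_fn(coords)
        torch.cuda.synchronize()
        elapsed = time.time() - start
    return num_points * iters / elapsed  # points/sec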
Compared to float32, PyTorch realizes roughly a 2x improvement, while my kernel only improves by ~20%. I have verified that the dtypes of the data are torch.float16, that they are cast properly, and that the CUDA functions return that dtype as well.
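Concretely, the check is just a dtype assertion around the op (using the EncodeCoordinates Function shown below; the coordinate tensor here is a stand-in, while rotations, scales, translations, and feature_grids are my actual model tensors):

coords = torch.rand(4096, 3, device="cuda")  # stand-in query coordinates
with torch.autocast(device_type="cuda", enabled=True, dtype=torch.float16):
    feats = EncodeCoordinates.apply(coords, rotations, scales,
                                    translations, feature_grids)
    assert feats.dtype == torch.float16  # output returned by _C.encodeForward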
I decorated the tested function as shown below:
import torch
from torch.amp import custom_fwd, custom_bwd

class EncodeCoordinates(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda", cast_inputs=torch.float16)
    def forward(ctx, query_coordinates, rotations, scales, translations, feature_grids):
        feature_vectors = _C.encodeForward(query_coordinates, rotations,
                                           scales, translations, feature_grids)
        ctx.save_for_backward(query_coordinates, rotations,
                              scales, translations, feature_grids)
        return feature_vectors

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        query_coordinates, rotations, scales, \
            translations, feature_grids = ctx.saved_tensors
        grad_feature_grids = _C.encodeBackward(query_coordinates,
                                               rotations, scales, translations,
                                               feature_grids, grad_output)
        return None, None, None, None, grad_feature_grids
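For completeness, this is roughly how the Function is wired into the encoder's forward pass (a sketch; the module name and how the tensors are stored are placeholders for my actual setup):

class Encoder(torch.nn.Module):
    # Hypothetical wrapper showing where EncodeCoordinates.apply is called;
    # only feature_grids is learnable, matching the backward above, which
    # returns a gradient for feature_grids only.
    def __init__(self, rotations, scales, translations, feature_grids):
        super().__init__()
        self.register_buffer("rotations", rotations)
        self.register_buffer("scales", scales)
        self.register_buffer("translations", translations)
        self.feature_grids = torch.nn.Parameter(feature_grids)

    def forward(self, query_coordinates):
        # Inside an autocast region, custom_fwd(cast_inputs=torch.float16)
        # casts these arguments to fp16 before _C.encodeForward sees them.
        return EncodeCoordinates.apply(query_coordinates, self.rotations,
                                       self.scales, self.translations,
                                       self.feature_grids)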
The C++ function bound with pybind looks like this:
torch::Tensor encode_forward(
    const torch::Tensor& query_points,
    const torch::Tensor& rotations,
    const torch::Tensor& scales,
    const torch::Tensor& translations,
    const torch::Tensor& feature_grids)
{
    const auto num_points = query_points.size(0);
    const auto num_grids = feature_grids.size(0);
    const auto features_per_grid = feature_grids.size(1);
    const auto D = feature_grids.size(2);
    const auto H = feature_grids.size(3);
    const auto W = feature_grids.size(4);

    auto options = torch::TensorOptions().dtype(query_points.dtype()).device(query_points.device());
    torch::Tensor out_features = torch::empty({num_points, num_grids * features_per_grid}, options);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(query_points.scalar_type(), "launch_encode_forward", ([&] {
        launch_encode_forward<scalar_t>(
            query_points,
            rotations,
            scales,
            translations,
            feature_grids,
            out_features
        );
    }));
    return out_features;
}
And the relevant CUDA code looks like this:
template <typename scalar_t>
struct scalar_t3 {
    scalar_t x;
    scalar_t y;
    scalar_t z;

    __device__ __host__ scalar_t3() : x(0), y(0), z(0) {}
    __device__ __host__ scalar_t3(scalar_t x_, scalar_t y_, scalar_t z_) : x(x_), y(y_), z(z_) {}
};

template <typename scalar_t>
__host__ __device__ __forceinline__ scalar_t3<scalar_t> make_scalar_t3(scalar_t x, scalar_t y, scalar_t z) {
    return scalar_t3<scalar_t>(x, y, z);
}

template <typename scalar_t>
void launch_encode_forward(
    const torch::Tensor& query_points,
    const torch::Tensor& rotations,
    const torch::Tensor& scales,
    const torch::Tensor& translations,
    const torch::Tensor& feature_grids,
    torch::Tensor& output_features)
{
    const int num_points = query_points.size(0);
    const int num_grids = rotations.size(0);

    // Allocate memory for rotation matrices
    scalar_t* rotation_matrices;
    cudaMalloc(&rotation_matrices, num_grids * 3 * 3 * sizeof(scalar_t));

    auto blocksPerGrid = (num_grids + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    quaternionScaleToRotationMatrix<<<blocksPerGrid, THREADS_PER_BLOCK>>>(
        rotations.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        scales.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        rotation_matrices
    );

    dim3 threadsPerBlock(THREADS_PER_BLOCK, 1);
    dim3 numBlocks((num_points + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, num_grids);
    encodeForwardKernel<<<numBlocks, threadsPerBlock>>>(
        query_points.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        rotation_matrices,
        translations.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        feature_grids.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        output_features.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>()
    );

    // Free the allocated memory
    cudaFree(rotation_matrices);
}

template <typename scalar_t>
__global__ void encodeForwardKernel(
    const at::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> query_points,
    const scalar_t* rotation_matrices,
    const at::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> translations,
    const at::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> feature_grids,
    at::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_features)
{
    const auto point_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const auto grid_idx = blockIdx.y;
    if (grid_idx >= feature_grids.size(0) || point_idx >= query_points.size(0)) return;

    scalar_t3<scalar_t> point = make_scalar_t3<scalar_t>(
        query_points[point_idx][0], query_points[point_idx][1], query_points[point_idx][2]);
    scalar_t3<scalar_t> point_t = transformPoint<scalar_t>(grid_idx, rotation_matrices, translations, point);
    trilinearInterpolate<scalar_t>(
        grid_idx,
        point_idx,
        feature_grids,
        point_t,
        output_features
    );
}
I also made sure the various precisions are being compiled in the CUDA code with
template void launch_encode_forward<float>(const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, at::Tensor&);
template void launch_encode_forward<double>(const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, at::Tensor&);
template void launch_encode_forward<c10::Half>(const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&, at::Tensor&);
added at the end of my .cu file containing all the CUDA code.
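As a quick sanity check that the c10::Half instantiation is actually the one being exercised, I can also bypass autocast/autograd and call the bound op directly on pre-cast half tensors, along these lines (a sketch; the shapes are illustrative guesses, e.g. quaternions for the rotations):

pts   = torch.rand(1024, 3, device="cuda").half()           # query points
rot   = torch.rand(16, 4, device="cuda").half()             # one quaternion per grid
scl   = torch.rand(16, 3, device="cuda").half()             # per-grid scales
trn   = torch.rand(16, 3, device="cuda").half()             # per-grid translations
grids = torch.rand(16, 2, 32, 32, 32, device="cuda").half() # feature grids
out = _C.encodeForward(pts, rot, scl, trn, grids)
print(out.dtype, out.shape)  # torch.float16, torch.Size([1024, 32]) with these shapes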
All of my code is here; I've left it out of this post for brevity (and because I suspect it might just come down to writing better kernels for fp16).
I had trouble finding documentation covering AMP with custom CUDA operations, and most of my help came from Claude through Cursor, so I'm wondering if there is something I'm missing to make sure the AMP path is working properly. If everything looks correct here, are there tips, resources, or examples of CUDA kernels written to work well with AMP that I could adjust my code to follow?