Templated Cooperative Group Kernel AT Dispatch

Has anyone had success with cooperative groups for grid synchronisation in their kernel when implementing new functions?

I’m getting unresolved overload errors when trying to use cudaLaunchCooperativeKernel(), whereas it was previously working fine when launching normally with kernel<<<...>>>(args). However, I need grid synchronisation during reduction, so I need to launch a cooperative kernel.

// Shorthand for a 32-bit-indexed packed tensor accessor with restrict pointer
// traits; used both for the kernel parameters and the host-side accessor
// locals built inside the dispatch lambda.
template <typename scalar_t, std::size_t dims>
using TensorAccR = torch::PackedTensorAccessor32<scalar_t,dims,torch::RestrictPtrTraits>;

// Kernel body reading a 4-D input accessor and writing a 2-D output accessor.
// Requires a cooperative launch (cudaLaunchCooperativeKernel): grid.sync()
// is undefined behaviour under a plain <<<...>>> launch.
template <typename scalar_t>
__global__ void a_kernel(const TensorAccR<scalar_t, 4> in_features, 
TensorAccR<scalar_t, 2> out_features)
{
   // Grid-wide group handle, valid only for cooperative launches.
   cooperative_groups::grid_group grid = cooperative_groups::this_grid();
   /*do things*/
   // Grid-wide barrier: every thread of every block must reach this point.
   grid.sync();
}

// Dispatches on the input tensor's floating dtype (incl. half) and launches
// a_kernel cooperatively so its grid.sync() is valid.
// NOTE(review): gridDims, blockDims and s_mem are not defined in this
// snippet — assumed to come from the enclosing scope; confirm.
void my_launcher(torch::Tensor features_in, torch::Tensor features_out)
{
   // Bug fix: the original read `features.scalar_type()`, but no `features`
   // exists here — the parameter is `features_in`.
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(features_in.scalar_type(), 
   "instance_pool_forward_kernel",
   [&] {
      TensorAccR<scalar_t, 4> features_in_acc = features_in.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>();
      TensorAccR<scalar_t, 2> features_out_acc = features_out.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>();

      // The runtime copies the pointed-to argument values at launch time,
      // so taking the address of these lambda-local accessors is safe.
      void *kernelArgs[] = { (void*)&features_in_acc, (void*)&features_out_acc};

      // The template instantiation must be cast to void*: the runtime API
      // takes an untyped function pointer, and without the cast overload
      // resolution on the function template fails to compile.
      cudaError_t err = cudaLaunchCooperativeKernel((void*)a_kernel<scalar_t>, gridDims,
                    blockDims, kernelArgs, s_mem, at::cuda::getCurrentCUDAStream());
      // Cooperative launches can fail at runtime (device without
      // cudaDevAttrCooperativeLaunch support, or a grid too large to be
      // co-resident) — surface that instead of failing silently.
      TORCH_CHECK(err == cudaSuccess,
                  "cooperative kernel launch failed: ", cudaGetErrorString(err));
   });
}

Sitting back down after lunch, it turns out all that was missing was a (void*) cast at the function call (shown below), so now it at least compiles fine.

cudaLaunchCooperativeKernel((void*)a_kernel<scalar_t>, gridDims,
                    blockDims, kernelArgs, s_mem, at::cuda::getCurrentCUDAStream());