Cast tensor according to template type

Dear all,

Assume this function:

template<typename scalar_t>
std::tuple<torch::Tensor,torch::Tensor> get_cheb_data(
        torch::Tensor & cheb_nodes,
        int &d,
        const std::string & gpu_device
        ){
    int n = (int) pow(cheb_nodes.size(0),d);
    torch::Tensor cheb_idx = torch::zeros({n,d}).toType(torch::kInt32).to(gpu_device);
    torch::Tensor cheb_data = torch::zeros({n,d}).toType(torch::kFloat32).to(gpu_device);
    dim3 block,grid;
    int shared;
    std::tie(block,grid,shared) =  get_kernel_launch_params<scalar_t>(d,n);
    get_cheb_idx_data<scalar_t><<<grid,block,shared>>>(
            cheb_nodes.packed_accessor32<scalar_t,1,torch::RestrictPtrTraits>(),
            cheb_data.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
            cheb_idx.packed_accessor32<int,2,torch::RestrictPtrTraits>()
            );

    return std::make_tuple(cheb_idx,cheb_data);

}

The requirement is that

    torch::Tensor cheb_data = torch::zeros({n,d}).toType(torch::kFloat32).to(gpu_device);

has to have type “scalar_t”, where scalar_t could be anything from Half to Double. What is the best way to do this?

Thank you!

Best regards,
Robert