I had the following code snippet and wanted to check the values of ci and cj. (I have removed most of the code in the kernel for readability)
#include <stdio.h>
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
namespace {
template <typename scalar_t>
__global__ void sparse_cuda_index_v2_kernel(
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> values,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> indices,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> queries,
const torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> num_clusters,
const torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> cluster_labels,
const torch::PackedTensorAccessor<scalar_t,1,torch::RestrictPtrTraits,size_t> start_idx,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> sumij,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> sumji,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> outputs)
{
const long ci = threadIdx.x;
const long cj = blockIdx.x;
printf("%d\n",cj);
}
} // namespace
std::vector<torch::Tensor> sparse_cuda_index_v2(
torch::Tensor values,
torch::Tensor indices,
torch::Tensor queries,
torch::Tensor num_clusters,
torch::Tensor cluster_labels,
torch::Tensor outputs) {
const long Nc = queries.size(0);
const long threads = Nc; //Value is 5
const dim3 blocks(Nc,1);
torch::Tensor sumij = torch::zeros({queries.size(0),values.size(0)}).to(torch::Device(torch::kCUDA, 0));
torch::Tensor sumji = torch::zeros({queries.size(0),values.size(0)}).to(torch::Device(torch::kCUDA, 0));
torch::Tensor start_idx = torch::zeros(queries.size(0)).to(torch::Device(torch::kCUDA, 0));
for(int i=1;i<num_clusters.size(0);i++){
start_idx[i] = start_idx[i-1] + num_clusters[i-1];
}
int64_t total = (start_idx[queries.size(0)-1]+num_clusters[queries.size(0)-1]).item<int64_t>();
assert(total==cluster_labels.size(0));
AT_DISPATCH_FLOATING_TYPES(values.type(), "indexing_v2", ([&] {
sparse_cuda_index_v2_kernel<scalar_t><<<blocks, threads>>>(
values.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
indices.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
queries.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
num_clusters.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
cluster_labels.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
start_idx.packed_accessor<scalar_t,1,torch::RestrictPtrTraits,size_t>(),
sumij.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
sumji.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
outputs.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));
return {outputs};
}
This prints all the possible values of cj (0,1,2,3,4) but when I add a print for ci. i.e replace
printf("%d\n",cj);
with
printf("%d %d\n",ci, cj);
cj prints only 0. What’s the reason for this?