Hi all.
I want to implement a function that computes a “reverse” mapping of the inverse tensor returned by torch.unique.
For example, torch.unique can return the unique values uniques of a long-type input x together with an inverse tensor that maps from uniques back to x (so that uniques[inverse] reconstructs x):
x = torch.LongTensor([9, 10, 9, 9, 10, 9])
uniques, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True)
# uniques = [9, 10]
# inverse = [0, 1, 0, 0, 1, 0]
# counts = [4, 2]
print((uniques[inverse] == x).all()) # True
My question: is there an efficient way to get a “reverse” mapping back_map that goes from x back to uniques, like this?
def reverse_unique(x): ...
uniques, inverse, counts, back_map = reverse_unique(x)
# uniques = [9, 10]
# inverse = [0, 1, 0, 0, 1, 0]
# counts = [4, 2]
# back_map = [0, 2, 3, 5, 1, 4]
print((x[back_map] == uniques.repeat_interleave(counts)).all()) # True
In the code above, back_map maps the groups of inverse values back to positions in the input x, i.e. it lists the indices of x grouped by their unique value.
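In other words, slicing back_map at the cumulative counts gives, for each unique value, the positions where it occurs in x (continuing the example above):

# positions of uniques[0] == 9 in x
print(back_map[0:4])  # tensor([0, 2, 3, 5])
# positions of uniques[1] == 10 in x
print(back_map[4:6])  # tensor([1, 4])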
I know it is not difficult to implement this with a Python loop (a naive sketch is below), but in my case the size of x can reach roughly 1e8 elements, so the time overhead is intolerable.
So, is there any high-level implementation using the PyTorch API, or an efficient CUDA kernel? I tried to parallelize it with a CUDA extension, but my kernel is extremely slow; the kernel and host code follow the loop sketch.
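For reference, this is roughly the loop version I want to avoid (a minimal sketch; reverse_unique_loop is just a placeholder name):

import torch

def reverse_unique_loop(x):
    uniques, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True)
    # exclusive prefix sum of counts = start offset of each unique value's group
    offsets = torch.cat([counts.new_zeros(1), counts.cumsum(0)[:-1]])
    fill = offsets.clone()
    back_map = torch.empty_like(x)
    for i in range(x.numel()):  # O(len(x)) Python loop -- intolerable at ~1e8 elements
        g = inverse[i]
        back_map[fill[g]] = i
        fill[g] += 1
    return uniques, inverse, counts, back_map

And here is my CUDA extension attempt: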
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>

// One thread per unique value: thread n scans the whole input and writes the
// positions of uniques[n] into out, starting at the exclusive-prefix-sum
// offset cumsum_counts[n]. Total work is O(num_uni * num_x).
__global__ void unique_back_map_kernel(
    int32_t num_uni,
    int32_t num_x,
    const int64_t* __restrict__ uniques,
    const int64_t* __restrict__ cumsum_counts,
    const int64_t* __restrict__ x,
    int64_t* __restrict__ out) {
  int32_t n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n >= num_uni) {
    return;
  }
  size_t found = 0;
  auto idx = __ldg(&cumsum_counts[n]);
  for (int64_t i = 0; i < num_x; ++i) {
    if (x[i] == uniques[n]) {
      out[idx + found] = i;
      ++found;
    }
  }
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> reverse_unique(at::Tensor x) {
  TORCH_CHECK(x.dtype() == at::kLong && x.device().is_cuda(),
              "reverse_unique expects a CUDA LongTensor");
  x = x.contiguous();
  auto [uniques, inverse, counts] = at::_unique2(x, /*sorted=*/false, /*return_inverse=*/true, /*return_counts=*/true);
  // python: cumsum_counts = torch.cat([torch.tensor([0]), counts.cumsum(0)[:-1]])
  auto cumsum_counts = at::cat({at::zeros({1}, counts.options()), counts.cumsum(0).slice(0, 0, -1)});
  auto back_map = at::empty_like(x);
  int32_t threads = (uniques.numel() > 256) ? 256 : 32;
  int32_t blocks = (uniques.numel() + threads - 1) / threads;
  unique_back_map_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
      uniques.numel(),
      x.numel(),
      uniques.data_ptr<int64_t>(),
      cumsum_counts.data_ptr<int64_t>(),
      x.data_ptr<int64_t>(),
      back_map.data_ptr<int64_t>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  // Return counts (not its prefix sum) so the interface matches the Python spec above.
  return std::make_tuple(uniques, inverse, counts, back_map);
}
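One pure-PyTorch route I have been wondering about is sorting inverse: a stable sort groups the positions of x by unique value and keeps first-occurrence order inside each group, which looks exactly like back_map. A minimal sketch (assuming torch.sort with stable=True behaves that way; I have not benchmarked it at 1e8 elements):

import torch

def reverse_unique_sort(x):
    uniques, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True)
    # sorting inverse groups the positions of x by unique value;
    # stable=True preserves the original order within each group
    _, back_map = torch.sort(inverse, stable=True)
    return uniques, inverse, counts, back_map

But I am not sure this is the fastest option on CUDA at this scale, so any advice is appreciated.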