Getting the histogram bin index of each element with PyTorch

Hi,
Is there a way to get a tensor holding, for each element of `values`, the index of the histogram bin it would fall into, like `tf.histogram_fixed_width_bins` does in TensorFlow?
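For reference, the TensorFlow op splits `value_range` into `nbins` equal-width bins and clips values outside the range into the first or last bin. Those semantics can be sketched in plain PyTorch roughly like this. The helper name `fixed_width_bins` is hypothetical, not a PyTorch API, and the sketch assumes a float tensor and a finite range with `lo < hi`:

```python
import torch

def fixed_width_bins(values: torch.Tensor, value_range: tuple, nbins: int) -> torch.Tensor:
    """Hypothetical helper mimicking tf.histogram_fixed_width_bins:
    equal-width bins over value_range, with out-of-range values
    clipped into the first or last bin."""
    lo, hi = value_range
    # Scale values so the range [lo, hi) maps to [0, nbins).
    scaled = (values - lo) / (hi - lo) * nbins
    indices = torch.floor(scaled).long()
    # Values below lo land in bin 0, values >= hi in bin nbins - 1.
    return indices.clamp(0, nbins - 1)

values = torch.tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15.0])
print(fixed_width_bins(values, (0.0, 5.0), nbins=5))
# tensor([0, 0, 1, 2, 4, 4])
```

Because of the clamp, out-of-range values cannot be told apart from values in the edge bins; drop the `clamp` call if you need to detect them.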

I’m not aware of a built-in method, but something like this might work:

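```python
import torch

x = torch.tensor([1., 2, 1, 8, 6, 6, 8, 12])
hist = torch.histogram(x, bins=4, range=(0., 10.))
print(hist.hist)
# tensor([3., 0., 2., 2.])
print(hist.bin_edges)
# tensor([ 0.0000,  2.5000,  5.0000,  7.5000, 10.0000])

# x - edge is positive for every edge strictly below x, so the running sum keeps
# growing up to the last such edge; argmax returns that edge's index, i.e. the bin.
# Values above the range (12 here) get index 4 == number of bins; values below get 0.
# Caveat: a value exactly on an inner edge (e.g. 2.5) gets the lower bin (0),
# whereas torch.histogram counts it in the upper bin (1).
idx = (x.unsqueeze(1) - hist.bin_edges.unsqueeze(0)).cumsum(1).argmax(1)
print(idx)
# tensor([0, 0, 0, 3, 2, 2, 3, 4])
```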
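As an alternative sketch (not the approach above), `torch.bucketize` can look up each value directly in the edges returned by `torch.histogram`. With `right=True`, a value lying exactly on an inner edge goes to the upper bin, which matches how `torch.histogram` counts it. The extra line for the rightmost edge assumes `torch.histogram` closes its last bin on the right, as NumPy's histogram does. The values 2.5 and 10 are added to the example only to exercise those edge cases:

```python
import torch

x = torch.tensor([1., 2, 1, 8, 6, 6, 8, 12, 2.5, 10])
hist = torch.histogram(x, bins=4, range=(0., 10.))
edges = hist.bin_edges
nbins = edges.numel() - 1

# right=True returns i with edges[i-1] <= x < edges[i]; subtracting 1 gives the
# index of the left-closed bin [edges[i-1], edges[i]).
idx = torch.bucketize(x, edges, right=True) - 1
# The last bin is closed on the right, so x == edges[-1] belongs to bin nbins - 1.
idx[x == edges[-1]] = nbins - 1
# Out-of-range values are now -1 (below the range) or nbins (above it).
print(idx)
# tensor([0, 0, 0, 3, 2, 2, 3, 4, 1, 3])
```

If you want TensorFlow's clipping behavior instead, `idx.clamp(0, nbins - 1)` maps out-of-range values into the first and last bins.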