Batched torch.histc

Will torch.histc support batched inputs in the future? Or is there a trick to achieve this now?

Assume I have a B x c input, where B is the batch size:

histc_input = torch.rand(B, c)
histc_output = torch.histc(histc_input, bins=100, min=0, max=1)

I want histc_output to have shape B x 100.
Do you think this is possible to implement?
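For reference, torch.histc flattens its input and returns a single 1-D histogram, so the call above yields shape (100,) rather than (B, 100):

print(histc_output.shape)  # torch.Size([100]): the batch dimension is collapsed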

I found one possible solution using one_hot here:

import torch

def batch_histogram(data_tensor, num_classes=-1):
    """
    Computes histograms of integral values, even in batches (as opposed to torch.histc and torch.histogram).
    Arguments:
        data_tensor: a D_1 x ... x D_n torch.LongTensor
        num_classes (optional): the number of classes present in data.
                                If not provided, tensor.max() + 1 is used (an error is thrown if the tensor is empty).
    Returns:
        A D_1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor,
        containing histograms of the last dimension D_n of tensor;
        that is, result[d_1, ..., d_{n-1}, c] = number of times c appears in tensor[d_1, ..., d_{n-1}].
    """
    # one_hot appends a trailing num_classes dimension; summing over the
    # original last dimension counts the occurrences of each class.
    return torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)
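A quick sanity check (my own example, not from the linked post):

data = torch.tensor([[0, 1, 1, 2],
                     [2, 2, 0, 0]])
print(batch_histogram(data, num_classes=3))
# tensor([[1, 2, 1],
#         [2, 0, 2]])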

But data_tensor must be converted to torch.LongTensor (or torch.cuda.LongTensor, depending on the device). If the input images have shape (b, 3, h, w) with values in the range (0, 1), and we want a (b, 3, num_classes) histogram output per batch and per channel, then what we need to do is:

num_classes = 10  # take 10 as an example
b, c, h, w = image.shape
# .long() truncates to integer bin indices and keeps the tensor on its current device
output = batch_histogram((image * num_classes).long().reshape(b, c, h * w), num_classes=num_classes)

As expected, the one_hot approach uses a lot of memory: it materializes a (b, 3, h*w, num_classes) intermediate tensor before summing.
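A more memory-friendly alternative (a sketch of my own, not from the post above) avoids the one-hot intermediate by accumulating counts in place with scatter_add_. Note one difference: torch.histc ignores out-of-range values, while this sketch clamps them into the edge bins.

import torch

def batch_histc(data, bins=100, vmin=0.0, vmax=1.0):
    """Batched torch.histc sketch for a (B, N) float tensor, returning (B, bins) counts."""
    # map each value to its bin index, clamped to [0, bins - 1]
    idx = ((data - vmin) / (vmax - vmin) * bins).long().clamp_(0, bins - 1)
    hist = torch.zeros(data.shape[0], bins, dtype=torch.long, device=data.device)
    # hist[b, idx[b, n]] += 1 for every element, with no one-hot intermediate
    return hist.scatter_add_(-1, idx, torch.ones_like(idx))

histc_output = batch_histc(torch.rand(8, 1024))  # shape (8, 100)

For large inputs this only allocates the (B, bins) output plus an index tensor the size of the input, instead of the full one-hot expansion.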