I found one possible solution, which uses one_hot:
import torch

def batch_histogram(data_tensor, num_classes=-1):
    """
    Computes histograms of integral values, even if in batches (as opposed to torch.histc and torch.histogram).
    Arguments:
        data_tensor: a D1 x ... x D_n torch.LongTensor
        num_classes (optional): the number of classes present in data.
                                If not provided, data_tensor.max() + 1 is used (an error is thrown if data_tensor is empty).
    Returns:
        A D1 x ... x D_{n-1} x num_classes 'result' torch.LongTensor,
        containing histograms of the last dimension D_n of data_tensor,
        that is, result[d_1,...,d_{n-1}, c] = number of times c appears in data_tensor[d_1,...,d_{n-1}].
    """
    return torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)
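To make the shape convention concrete, here is a small worked example. The input values are made up for illustration, and the function is repeated only so that the snippet runs on its own.

```python
import torch

def batch_histogram(data_tensor, num_classes=-1):
    # Same one-liner as above, repeated so this snippet is self-contained.
    return torch.nn.functional.one_hot(data_tensor, num_classes).sum(dim=-2)

# Illustrative input: a batch of 2 rows, each holding 5 integer labels in {0, 1, 2}.
data = torch.tensor([[0, 1, 1, 2, 2],
                     [2, 2, 2, 0, 0]])

hist = batch_histogram(data, num_classes=3)
print(hist.shape)  # torch.Size([2, 3])
print(hist)        # row 0 counts: [1, 2, 2]; row 1 counts: [2, 0, 3]
```

The last dimension of the input (5 labels per row) is replaced by a dimension of size num_classes holding the counts.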
Note that data_tensor must be converted to torch.LongTensor (or torch.cuda.LongTensor, depending on the device). For example, if the input images have shape (b,3,h,w) with values in the range (0,1), and we want a (b,3,num_classes) histogram for each channel of each image in the batch, then what we need to do is:
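num_class = 10  # take 10 as an example
b, c, h, w = image.shape
# Scale to bin indices; .long() keeps the tensor on its device, and clamp guards against a value of exactly 1.0
bins = (image * num_class).long().clamp(max=num_class - 1).reshape(b, c, h * w)
output = batch_histogram(data_tensor=bins, num_classes=num_class)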
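As expected, the one_hot function uses a lot of memory: it materializes an intermediate int64 tensor with num_classes entries for every input element before the sum is taken.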
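If memory becomes a problem, one possible alternative is to accumulate the counts directly with scatter_add_, which should avoid the intermediate that is num_classes times larger than the input. This is only a sketch, not part of the solution I found: batch_histogram_scatter is a hypothetical name, it assumes num_classes is given explicitly, and it assumes every value lies in [0, num_classes).

```python
import torch

def batch_histogram_scatter(data_tensor, num_classes):
    # Hypothetical helper (not from the original solution): produces the same
    # output as the one_hot version, but writes counts straight into the result.
    # Assumes data_tensor is a LongTensor with values in [0, num_classes).
    out_shape = (*data_tensor.shape[:-1], num_classes)
    hist = torch.zeros(out_shape, dtype=torch.long, device=data_tensor.device)
    # For every element, add 1 at index `value` along the last dimension.
    hist.scatter_add_(-1, data_tensor, torch.ones_like(data_tensor))
    return hist

# Sanity check against the one_hot approach on random labels.
torch.manual_seed(0)
labels = torch.randint(0, 10, (4, 3, 64))  # e.g. b=4, c=3, h*w=64
reference = torch.nn.functional.one_hot(labels, 10).sum(dim=-2)
result = batch_histogram_scatter(labels, 10)
print(torch.equal(result, reference))  # True
```

The trade-off is that the number of classes can no longer be inferred with -1; it has to be passed in.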