Batched bincount

I would like to apply torch.bincount to a batch of vectors, assuming the maximum possible value is the same for every vector in the batch. How can I do it? bincount does not support batching, and a for-loop would not be efficient.

Thank you

You can use scatter_add_(dim, index, src) for a batched bincount, where ‘index’ is your input and ‘src’ is just torch.ones(…).
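
To make that concrete, here is a minimal sketch on a tiny batch. The tensor values and the name num_bins are made up for illustration, and it assumes the input is a 2-D int64 tensor whose values all lie in [0, num_bins).

```python
import torch

# Two vectors of length 4; values are assumed to lie in [0, num_bins).
x = torch.tensor([[0, 1, 1, 3],
                  [2, 2, 2, 0]])  # int64, as scatter_add_ requires for the index
num_bins = 4  # illustrative: largest possible value + 1

# One row of counters per vector in the batch.
counts = torch.zeros(x.shape[0], num_bins, dtype=torch.int64)

# For every (i, j), add 1 to counts[i, x[i, j]].
counts.scatter_add_(1, x, torch.ones_like(x))

print(counts)
# tensor([[1, 2, 0, 1],
#         [1, 0, 3, 0]])
```

Each row of the result is the bincount of the corresponding input row, padded to num_bins entries.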

Hi, were you ever able to get a batched bincount working? I’m facing the same issue, where looping over a dim is not an option.

This is what I got

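import torch

def batched_bincount(x, dim, max_value):
    # x: 2-D int64 tensor with values in [0, max_value); max_value is the number of bins.
    # target has shape (x.shape[0], max_value), so dim should be 1 for a (batch, n) input.
    target = torch.zeros(x.shape[0], max_value, dtype=x.dtype, device=x.device)
    values = torch.ones_like(x)
    target.scatter_add_(dim, x, values)  # adds 1 at target[i, x[i, j]] for each j
    return target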
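
A quick way to sanity-check this kind of function is to compare it against the plain per-row loop over torch.bincount on random data. The sketch below is self-contained, so it redefines the scatter-based version under a hypothetical name (batched_bincount_2d); the shapes, seed, and bin count are arbitrary choices for the check.

```python
import torch

def batched_bincount_2d(x, num_bins):
    # Hypothetical helper for this check: x is (batch, n), int64, values in [0, num_bins).
    counts = torch.zeros(x.shape[0], num_bins, dtype=torch.int64, device=x.device)
    counts.scatter_add_(1, x, torch.ones_like(x))
    return counts

torch.manual_seed(0)
num_bins = 5
x = torch.randint(0, num_bins, (3, 10))  # 3 vectors of length 10, int64 by default

batched = batched_bincount_2d(x, num_bins)

# Reference: the slow loop, with minlength so every row has num_bins entries.
looped = torch.stack([torch.bincount(row, minlength=num_bins) for row in x])

assert torch.equal(batched, looped)
```

The loop is only there as a reference; the scatter-based version does the same counting in a single call.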