Hi, I would like to apply torch.bincount to a batch of vectors, assuming the maximum possible value is the same for every vector in the batch. How can I do it? torch.bincount does not support batched input, and a for-loop would not be efficient.
Thank you
You can just use `scatter_add_(dim, index, src)` for a batched bincount: start from a zero tensor of shape (batch, max_value), pass your input as `index`, and pass `torch.ones(…)` (ones with the same shape as the input) as `src`.
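A minimal sketch of the `scatter_add_` approach, assuming a 2-D `torch.long` input of shape (batch, n) whose values all lie in `[0, num_bins)`. The function name `batched_bincount_scatter` and the `num_bins` parameter are illustrative, not part of PyTorch.

```python
import torch


def batched_bincount_scatter(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    """Hypothetical helper: per-row bincount of a 2-D long tensor.

    Assumes x has shape (batch, n) and 0 <= x < num_bins.
    Returns a (batch, num_bins) tensor of counts.
    """
    # One row of zero counts per batch element.
    counts = torch.zeros(x.size(0), num_bins, dtype=torch.long, device=x.device)
    # Each occurrence of value v in row b adds 1 to counts[b, v].
    counts.scatter_add_(1, x, torch.ones_like(x))
    return counts


x = torch.tensor([[0, 1, 1, 3],
                  [2, 2, 2, 0]])
print(batched_bincount_scatter(x, num_bins=4))
# tensor([[1, 2, 0, 1],
#         [1, 0, 3, 0]])
```

Each row of the result matches what `torch.bincount(row, minlength=num_bins)` would give for that row, but the whole batch is handled in one call.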
Hi, were you ever able to get a batched bincount working? I'm facing the same issue, and looping over a dimension is not an option.
Hi, this is what I came up with:
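```python
import torch as ch


def batched_bincount(x, dim, max_value):
    # x: integer (long) tensor of shape (batch, n) with values in [0, max_value).
    # One row of zero counts per batch element, same dtype and device as x.
    target = ch.zeros(x.shape[0], max_value, dtype=x.dtype, device=x.device)
    values = ch.ones_like(x)
    # For a 2-D input, dim should be 1: each value v in row b adds 1 to target[b, v].
    target.scatter_add_(dim, x, values)
    return target
```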
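A different technique, not from this thread, gives the same result with a single `torch.bincount` call: shift row `b` by `b * num_bins` so every row lands in its own range of bins, count the flattened tensor, then reshape. This sketch assumes a 2-D `torch.long` input with values in `[0, num_bins)`; the name `batched_bincount_offset` is hypothetical.

```python
import torch


def batched_bincount_offset(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    # Assumes x is a 2-D long tensor with 0 <= x < num_bins.
    batch = x.size(0)
    # Row b is shifted into its own range [b * num_bins, (b + 1) * num_bins).
    offsets = torch.arange(batch, device=x.device).unsqueeze(1) * num_bins
    flat = (x + offsets).flatten()
    # minlength guarantees exactly batch * num_bins bins, even if trailing values never occur.
    counts = torch.bincount(flat, minlength=batch * num_bins)
    return counts.view(batch, num_bins)


x = torch.tensor([[0, 1, 1, 3],
                  [2, 2, 2, 0]])
print(batched_bincount_offset(x, num_bins=4))
```

One possible advantage of this variant is that `torch.bincount` also accepts a `weights` tensor, so a flattened weight tensor can be passed alongside `flat` for a batched weighted histogram.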