Batched bincount

Hi,
This is what I came up with:

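import torch as ch  # assumption: `ch` in the post is an alias for torch

def batched_bincount(x, dim, max_value):
    # x: int64 tensor of shape (batch, n) with values in [0, max_value)
    target = ch.zeros(x.shape[0], max_value, dtype=x.dtype, device=x.device)
    values = ch.ones_like(x)
    # with dim=1, adds 1 to target[i, x[i, j]] for every j
    target.scatter_add_(dim, x, values)
    return target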
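
A quick sanity check of the function, as a sketch under a few assumptions: `ch` is `torch`, `x` is a 2D int64 tensor of shape (batch, n), `dim=1`, and every value in `x` is below `max_value`. The sample tensor is made up for illustration, and the result is compared against calling `torch.bincount` one row at a time.

```python
import torch

def batched_bincount(x, dim, max_value):
    # One row of counts per batch element, one column per possible value
    target = torch.zeros(x.shape[0], max_value, dtype=x.dtype, device=x.device)
    values = torch.ones_like(x)
    # With dim=1 this adds 1 to target[i, x[i, j]] for every j
    target.scatter_add_(dim, x, values)
    return target

# Hypothetical input: 2 rows, values in [0, 4)
x = torch.tensor([[0, 1, 1, 3],
                  [2, 2, 2, 0]])

counts = batched_bincount(x, dim=1, max_value=4)

# Reference: plain torch.bincount applied row by row
expected = torch.stack([torch.bincount(row, minlength=4) for row in x])

print(counts)
print(torch.equal(counts, expected))
```

Note that `max_value` acts as the number of bins, so the largest value in `x` has to be strictly smaller than it, and `x` needs to be int64 because `scatter_add_` uses it as the index tensor.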