Hi,
This is what I came up with:
def batched_bincount(x, dim, max_value):
    """Count occurrences of each integer value in ``x`` along dimension ``dim``.

    A batched analogue of ``torch.bincount``: for a 2-D ``x`` of shape
    ``(batch, n)`` and ``dim=1`` the result has shape ``(batch, max_value)``
    and ``result[b, v]`` is the number of times ``v`` appears in ``x[b]``.
    In general the result has the same shape as ``x`` except that the size
    of dimension ``dim`` is replaced by ``max_value``.

    Args:
        x: Integer index tensor (``scatter_add_`` requires int64 indices);
            every value must lie in ``[0, max_value)``.
        dim: Dimension along which to count; negative values are allowed.
        max_value: Number of bins, i.e. one more than the largest value
            that may occur in ``x``.

    Returns:
        A tensor with ``x``'s dtype and device holding the per-bin counts.
    """
    # The output matches x everywhere except the counted dimension, which
    # becomes the bins axis. Deriving it from x.shape (rather than hard-coding
    # (x.shape[0], max_value)) makes any rank and any dim work.
    shape = list(x.shape)
    shape[dim] = max_value
    target = ch.zeros(shape, dtype=x.dtype, device=x.device)
    values = ch.ones_like(x)
    # For every element of x, add 1 to target at the same position, with the
    # coordinate along `dim` replaced by that element's value.
    target.scatter_add_(dim, x, values)
    return target
Hi,
This is what I came up with:
def batched_bincount(x, dim, max_value):
    """Count occurrences of each integer value in ``x`` along dimension ``dim``.

    A batched analogue of ``torch.bincount``: for a 2-D ``x`` of shape
    ``(batch, n)`` and ``dim=1`` the result has shape ``(batch, max_value)``
    and ``result[b, v]`` is the number of times ``v`` appears in ``x[b]``.
    In general the result has the same shape as ``x`` except that the size
    of dimension ``dim`` is replaced by ``max_value``.

    Args:
        x: Integer index tensor (``scatter_add_`` requires int64 indices);
            every value must lie in ``[0, max_value)``.
        dim: Dimension along which to count; negative values are allowed.
        max_value: Number of bins, i.e. one more than the largest value
            that may occur in ``x``.

    Returns:
        A tensor with ``x``'s dtype and device holding the per-bin counts.
    """
    # The output matches x everywhere except the counted dimension, which
    # becomes the bins axis. Deriving it from x.shape (rather than hard-coding
    # (x.shape[0], max_value)) makes any rank and any dim work.
    shape = list(x.shape)
    shape[dim] = max_value
    target = ch.zeros(shape, dtype=x.dtype, device=x.device)
    values = ch.ones_like(x)
    # For every element of x, add 1 to target at the same position, with the
    # coordinate along `dim` replaced by that element's value.
    target.scatter_add_(dim, x, values)
    return target