I’d like to compute flat 5-bin histogram counts using `scatter_(..., reduce='add')`. I expect the following to produce `[0.0, 16.0, 0.0, 0.0, 0.0]`. What am I doing wrong dimension-wise? Also, the error message is hard to parse (especially the latter part).
Thanks!
```python
import torch

histsize = 5
Z = torch.ones(16, dtype=torch.int64)
print(torch.zeros((histsize,), dtype=torch.float32).scatter_(-1, Z, 1.0, reduce='add'))
```

```
RuntimeError: Expected index [16] to be smaller than self [5] apart from dimension 0 and to be smaller size than src [5]
```
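For reference, a 1-D scatter-add computes `out[index[i]] += src[i]` for every `i`. The `torch.Tensor.scatter_` docs also require `index.size(d) <= src.size(d)` for every dimension `d`, so `src` must supply at least one value per index entry. The error above compares an index of size [16] against a src of size [5]. Below is a plain-Python sketch of those semantics; `scatter_add_1d` is a hypothetical helper, not a PyTorch function. The second call imitates the mismatch, under the assumption that the scalar-value form is validated against a src shaped like `self`.

```python
def scatter_add_1d(out, index, src):
    """Hypothetical reference helper: out[index[i]] += src[i] for a 1-D list.

    Mirrors the documented rule index.size(d) <= src.size(d): src must hold
    at least one value per index entry.
    """
    if len(index) > len(src):
        raise ValueError(f"index has {len(index)} entries but src only {len(src)}")
    for i, bin_id in enumerate(index):
        if not 0 <= bin_id < len(out):
            raise IndexError(f"bin {bin_id} out of range for {len(out)} bins")
        out[bin_id] += src[i]  # accumulate, so repeated bins add up
    return out


histsize = 5
Z = [1] * 16  # sixteen samples, all falling in bin 1

# One src value per index entry: every sample adds 1.0 to its bin
hist = scatter_add_1d([0.0] * histsize, Z, [1.0] * len(Z))
print(hist)  # [0.0, 16.0, 0.0, 0.0, 0.0]

# A src with only `histsize` values is too short for a 16-entry index
try:
    scatter_add_1d([0.0] * histsize, Z, [1.0] * histsize)
except ValueError as err:
    print(err)  # index has 16 entries but src only 5
```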
This works:
```python
import torch

histsize = 5
Z = torch.ones(16, dtype=torch.int64)
torch.zeros((histsize,), dtype=torch.int64).scatter_add_(-1, Z, torch.ones_like(Z))
```
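A sketch of a few equivalent ways to get the same float counts, assuming a 1-D `int64` index tensor like `Z` and a reasonably recent PyTorch. `scatter_add_`, `index_add_` and `torch.bincount` are existing PyTorch functions. The variable names `h1`, `h2` and `h3` are only for illustration.

```python
import torch

histsize = 5
Z = torch.ones(16, dtype=torch.int64)  # 16 samples, all falling in bin 1

# 1) scatter_add_ with a float src holding one value per index entry
h1 = torch.zeros(histsize, dtype=torch.float32).scatter_add_(
    0, Z, torch.ones(Z.shape, dtype=torch.float32)
)

# 2) index_add_: adds source[i] into self[index[i]] along dim 0
h2 = torch.zeros(histsize).index_add_(0, Z, torch.ones(Z.numel()))

# 3) bincount: counts each non-negative integer; minlength pads the result to histsize
h3 = torch.bincount(Z, minlength=histsize).to(torch.float32)

print(h1.tolist())  # [0.0, 16.0, 0.0, 0.0, 0.0]
```

For plain counts, `torch.bincount` is the most direct option. `scatter_add_` and `index_add_` generalize to weighted histograms by passing per-sample weights as the source.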