Hi,
I want to prototype a function in PyTorch, and to do so I need an operation that behaves like scatter_add but keeps the maximum instead of the sum.
I have a list of indices and corresponding values; both are large, multidimensional, and on the GPU.
Some indices will collide – this is the crucial part! – and in such a case I want the maximum to be kept.
(Of course the minimum would also work just fine).
It would also be great to have argmax and argmin variants of this.
Here is a minimal example in numpy:
import numpy as np
ids = np.array([1,2,3,4,1,2,3,1]) # There are multiple 1s here
vals = np.array([0,0,0,1,2,3,1,1]) # 0, 2, 1 will be written at index 1
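maxs = np.zeros_like(vals)
sums = np.zeros_like(vals)  # np.add.at accumulates sums, not averages
np.maximum.at(maxs, ids, vals)  # colliding indices keep the maximum
np.add.at(sums, ids, vals)  # colliding indices are summed
print(maxs)
print(sums)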
# idx: 0 1 2 3 4 5 6 7
>>>   [0 2 3 1 1 0 0 0]
>>>   [0 3 3 1 1 0 0 0]
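
To make the argmax request concrete, here is a rough NumPy sketch of the behaviour I have in mind, using the same data as above. Note that `scatter_argmax` is a made-up helper name, not an existing NumPy or PyTorch function. Filling untouched slots with -1 and letting the later position win on ties are my own assumptions; other conventions would work just as well.

```python
import numpy as np

def scatter_argmax(ids, vals, size, fill=-1):
    """Hypothetical helper: for each output slot, return the position in
    `vals` of the largest value scattered to it (`fill` if nothing lands there)."""
    # Sort positions by id first, then by value, so that within each id
    # group the largest value comes last (ties: the later position wins).
    order = np.lexsort((vals, ids))
    sorted_ids = ids[order]
    # Mark the last element of every id group.
    is_last = np.ones(len(ids), dtype=bool)
    is_last[:-1] = sorted_ids[1:] != sorted_ids[:-1]
    out = np.full(size, fill, dtype=np.int64)
    out[sorted_ids[is_last]] = order[is_last]
    return out

ids = np.array([1, 2, 3, 4, 1, 2, 3, 1])
vals = np.array([0, 0, 0, 1, 2, 3, 1, 1])
argmaxs = scatter_argmax(ids, vals, size=len(vals))
print(argmaxs)  # slots 1..4 hold positions 4, 5, 6, 3; all other slots hold -1
```

Indexing `vals` with the non-negative entries of `argmaxs` gives back the maxima 2, 3, 1, 1 from the example above. An argmin version would only need the values negated before sorting.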