PyTorch equivalent function for np.minimum.at or np.maximum.at

Hi,

I want to prototype a function in PyTorch, and for that I need a function that behaves like scatter_add but keeps the maximum instead of the sum.
I have a list of indices and corresponding values, both large, multidimensional and on the GPU.
Some indices will collide (this is the crucial part!), and in that case I want the maximum to be kept.
(Of course, the minimum would work just as well.)
It would also be great to have the same for argmax and argmin.

Here is a minimal example in numpy:

import numpy as np

ids  = np.array([1, 2, 3, 4, 1, 2, 3, 1])  # index 1 appears three times
vals = np.array([0, 0, 0, 1, 2, 3, 1, 1])  # 0, 2 and 1 are all written to index 1
maxs = np.zeros_like(vals)
sums = np.zeros_like(vals)

np.maximum.at(maxs, ids, vals)  # colliding indices keep the maximum
np.add.at(sums, ids, vals)      # colliding indices accumulate the sum
print(maxs)
print(sums)

# index:  0 1 2 3 4 5 6 7
# maxs:  [0 2 3 1 1 0 0 0]
# sums:  [0 3 3 1 1 0 0 0]
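On newer PyTorch releases (assuming 1.12 or later, where Tensor.scatter_reduce_ takes a src tensor and an include_self flag), the NumPy example above can be reproduced without an extra library. A minimal CPU sketch, using float values to keep things simple:

```python
import torch

# Same data as the NumPy example; indices must be int64.
ids = torch.tensor([1, 2, 3, 4, 1, 2, 3, 1])
vals = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 1.0, 1.0])

# Equivalent of np.maximum.at(maxs, ids, vals): start from zeros and keep the
# maximum at every index, including repeated ones. include_self=True lets the
# initial zeros take part in the reduction, just like in NumPy.
maxs = torch.zeros_like(vals).scatter_reduce_(0, ids, vals, reduce="amax", include_self=True)

# Equivalent of np.add.at(sums, ids, vals): colliding indices are accumulated.
sums = torch.zeros_like(vals).index_add_(0, ids, vals)

print(maxs.tolist())  # [0.0, 2.0, 3.0, 1.0, 1.0, 0.0, 0.0, 0.0]
print(sums.tolist())  # [0.0, 3.0, 3.0, 1.0, 1.0, 0.0, 0.0, 0.0]
```

Using reduce="amin" instead gives the np.minimum.at behaviour, and moving ids and vals to the GPU with .to("cuda") should need no other changes. For multidimensional data, scatter_reduce_ expects index to have the same number of dimensions as src, so the index tensor may need to be expanded first.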

Found a solution myself:

There is a library called pytorch_scatter that provides many different scatter operations (add, div, max, mean, min, mul, std). Its scatter_max returns both the maximum values and their indices, so it can also be used directly for argmax operations. The operations also come with the corresponding backward passes.
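To illustrate what a values-plus-indices result looks like, here is a pure-PyTorch sketch that computes a similar (max, argmax) pair without pytorch_scatter. scatter_max_with_argmax is a hypothetical helper, not part of that library; it assumes PyTorch 1.12 or later for scatter_reduce_ and a 1-D floating-point src, fills empty slots with -inf and -1 by its own convention, and resolves ties to the earliest position.

```python
import torch


def scatter_max_with_argmax(src, index, dim_size):
    """Hypothetical helper (not part of pytorch_scatter).

    For a 1-D floating-point `src`, returns the per-slot maximum and the
    position in `src` that produced it. Slots that receive no value keep
    -inf and get argmax -1 (a convention chosen for this sketch).
    """
    n = src.numel()

    # Per-slot maximum. include_self=False means the -inf fill does not take
    # part in the reduction; untouched slots simply stay at -inf.
    out = torch.full((dim_size,), float("-inf"), dtype=src.dtype, device=src.device)
    out.scatter_reduce_(0, index, src, reduce="amax", include_self=False)

    # An element "wins" if it equals the maximum of the slot it was sent to.
    positions = torch.arange(n, device=src.device)
    is_max = src == out[index]

    # Among the winners of each slot, keep the smallest position (first tie wins).
    # Losers are replaced by n, which is larger than any real position.
    candidates = torch.where(is_max, positions, torch.full_like(positions, n))
    argmax = torch.full((dim_size,), n, dtype=torch.long, device=src.device)
    argmax.scatter_reduce_(0, index, candidates, reduce="amin", include_self=True)
    argmax[argmax == n] = -1  # slots that received no value
    return out, argmax


# Same data as the NumPy example, as floats.
ids = torch.tensor([1, 2, 3, 4, 1, 2, 3, 1])
vals = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 1.0, 1.0])
out, argmax = scatter_max_with_argmax(vals, ids, dim_size=8)
print(out.tolist())     # [-inf, 2.0, 3.0, 1.0, 1.0, -inf, -inf, -inf]
print(argmax.tolist())  # [-1, 4, 5, 6, 3, -1, -1, -1]
```

Swapping "amax" for "amin" and -inf for +inf in the first reduction gives an argmin variant. With pytorch_scatter installed (imported as torch_scatter), scatter_max(vals, ids, dim=0, dim_size=8) returns an analogous (out, argmax) pair, but its fill values for slots that receive nothing may differ from the ones chosen here, so check its documentation.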
