`index_max_()` functionality? (and others)

I am currently trying to figure out a good index_max_ solution that doesn’t require multiple copies of the data or a loop. This would work along the lines of index_add_ except the underlying operation is a max instead of addition. Does anyone have a solution?

Also, are there plans for more built-in functionality along the lines of index_add_, index_copy_, etc.? Even better, it would be awesome to have some lambda-style way of supplying a custom reduction to these indexing operations.

In particular, I’m finding occasions where an index_max_ or index_avg_ would be useful. For example, if I have a list of indices corresponding to data (nearest-neighbor pointers or something similar), I could use it to pool the data over those indices.

For the averaging operation, I currently use index_add_ on my data and also on a torch.ones() vector and then do the appropriate broadcasted division. I suppose it may be more efficient to do that on the backend though.
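For anyone who wants to see that averaging workaround spelled out, here is a minimal sketch. The tensor values, the variable names, and the choice of pooling along dim 0 are made up for illustration, and the clamp for empty groups is an added assumption rather than part of the original approach.

```python
import torch

# Hypothetical example data: 5 rows of features pooled into 3 groups.
data = torch.tensor([[1.0, 2.0],
                     [3.0, 4.0],
                     [5.0, 6.0],
                     [7.0, 8.0],
                     [9.0, 10.0]])
index = torch.tensor([0, 0, 1, 2, 2])  # group id for each row of `data`
num_groups = 3

# Sum the rows that share a group id.
sums = torch.zeros(num_groups, data.size(1)).index_add_(0, index, data)

# Count the rows per group by accumulating a vector of ones with the same index.
counts = torch.zeros(num_groups).index_add_(0, index, torch.ones(len(index)))

# Broadcasted division; the clamp avoids dividing by zero for empty groups
# (an assumption added here, empty groups end up with a mean of 0).
means = sums / counts.clamp(min=1).unsqueeze(1)
print(means)
```

This needs two index_add_ passes plus one division, which is the extra work a built-in index_avg_ would presumably fold into a single call.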

In case anyone else finds this: you can compute an index max using scatter_reduce with the amax reduction:

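import torch

# Written against the PyTorch 1.11 signature; later releases changed the arguments.
input = torch.tensor([[1, 2, 3, 4, 5, 6, 3.5],
                      [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 3.75]])
index = torch.tensor([[0, 0, 0, 1, 1, 2, 2],
                      [2, 0, 1, 1, 1, 1, 1]])
# Along the last dim, take the max of all input values that share an index.
torch.scatter_reduce(input, -1, index, reduce="amax", output_size=3)
# tensor([[3.0000, 5.0000, 6.0000],
#         [2.5000, 6.5000, 1.5000]])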

For details, please refer to the torch.scatter_reduce page in the PyTorch 1.11.0 documentation.
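The snippet above follows the 1.11 signature. In PyTorch 2.x, scatter_reduce takes an explicit src tensor and an include_self flag instead of output_size, so an index_add_-style max along dim 0 might look like the sketch below. The helper name index_max and the example data are hypothetical, and this is only one way to write it, not an official API.

```python
import torch

def index_max(data, index, num_groups):
    """Hypothetical helper: out[g] = max over rows i of data with index[i] == g."""
    # scatter-style ops expect an index shaped like the source, so expand the
    # per-row group ids across the feature dimension.
    expanded = index.unsqueeze(1).expand_as(data)
    # Start from -inf; groups that receive no rows keep this value.
    out = torch.full((num_groups, data.size(1)), float("-inf"), dtype=data.dtype)
    # include_self=False leaves the initial contents of `out` out of the reduction.
    return out.scatter_reduce_(0, expanded, data, reduce="amax", include_self=False)

data = torch.tensor([[1.0, 2.0],
                     [3.0, 4.0],
                     [5.0, 6.0],
                     [7.0, 8.0],
                     [9.0, 10.0]])
index = torch.tensor([0, 0, 1, 2, 2])  # group id for each row of `data`

pooled = index_max(data, index, num_groups=3)
print(pooled)
```

Swapping reduce="amax" for "amin", "sum", "prod", or "mean" gives the other poolings through the same call, which covers part of the lambda-style flexibility asked about in the original question.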