I posted this question on Stack Overflow ("PyTorch - how to find the max values at certain indices ?? (PyTorch equivalent of tf.math.unsorted_segment_max)") but haven't gotten any suggestions, so I'll post it again here.

In PyTorch, I need to efficiently find the max values at certain indices. TensorFlow has a built-in function, `tf.math.unsorted_segment_max`, that does exactly this. I'll borrow an example from its documentation (https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_max):

```
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
data = tf.constant([5, 1, 7, 2, 3, 4], dtype=tf.float32)
segment_idxs = tf.constant([0, 0, 1, 1, 0, 1], dtype=tf.int64)
num_segments = 2
max_at_idxs = tf.math.unsorted_segment_max(data, segment_idxs, num_segments)
print('\n' + 'max_at_idxs: ')
print(max_at_idxs)
```

output:

```
max_at_idxs:
tf.Tensor([5. 7.], shape=(2,), dtype=float32)
```

I realize that, logically, I could use a loop like this:

```
import torch

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
num_segments = 2
# start every segment at -inf so any real data value replaces it
max_at_idxs = torch.full(size=(num_segments, ), fill_value=float('-inf'), dtype=torch.float32)
for i, segment_idx in enumerate(segment_idxs):
    # keep the larger of the current element and the running max of its segment
    max_at_idxs[segment_idx] = torch.max(data[i], max_at_idxs[segment_idx])
# end for
print('\n' + 'max_at_idxs: ')
print(max_at_idxs)
```

However, in my actual application `data` and `segment_idxs` are ~200,000 items long and this has to be done often, so a Python loop is way too slow.
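A possible vectorized replacement for the loop, assuming PyTorch 1.12 or newer (where `Tensor.scatter_reduce` with `reduce="amax"` was added, initially as a beta API): each `data[i]` is scattered into the slot `segment_idxs[i]`, and every slot keeps the maximum of what lands in it. The initial value uses the lowest finite value of the dtype, which is what the TensorFlow documentation says `unsorted_segment_max` outputs for a segment that receives no elements; this is a sketch, not a verified drop-in equivalent for every dtype.

```python
import torch

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
num_segments = 2

# Initial value for every segment: the lowest finite float, mirroring what
# TensorFlow reports for segments that receive no elements.
init = torch.full((num_segments,), torch.finfo(data.dtype).min, dtype=data.dtype)

# Scatter each data[i] into slot segment_idxs[i], keeping the maximum per slot.
# include_self=True means the initial value also takes part in the max.
max_at_idxs = init.scatter_reduce(0, segment_idxs, data, reduce="amax", include_self=True)
print(max_at_idxs)  # tensor([5., 7.])
```

The same call should work unchanged on 200,000-element tensors and on CUDA tensors, since no Python-level loop is involved.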

Similar questions have been asked on the PyTorch forum before, notably:

- How to perform segment max?
- Is there a pytorch implementation of tensor flow segment softmax? (reply #4 by tom)
- Pytorch equivalent to `tf.unsorted_segment_sum`

However, none of these solves my problem, so I figured I'd ask here.

I keep thinking there has to be a way to do this with fancy indexing, possibly in combination with `torch.where` or some other built-in function, but I can't seem to work it out.
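One way the `torch.where` idea can be made to work is a sketch like the following, which only relies on broadcasting, `torch.where` and `Tensor.amax`, so it should also run on PyTorch versions older than 1.12. It builds a boolean membership mask of shape `(num_segments, N)`, replaces non-members with `-inf`, and takes a row-wise max. The trade-off is memory: the mask has `num_segments * N` entries, so this is only practical when the number of segments is small.

```python
import torch

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
num_segments = 2

# mask[s, i] is True when element i belongs to segment s; shape (num_segments, N)
mask = segment_idxs.unsqueeze(0) == torch.arange(num_segments).unsqueeze(1)

# Broadcast data across rows; positions outside the segment become -inf
neg_inf = torch.tensor(float("-inf"), dtype=data.dtype)
masked = torch.where(mask, data.unsqueeze(0), neg_inf)

# Row-wise max gives one value per segment (-inf for an empty segment)
max_at_idxs = masked.amax(dim=1)
print(max_at_idxs)  # tensor([5., 7.])
```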

Also, I keep looking at the `torch.maximum` function, trying to find a way to use its second argument as indices rather than values, but I can't work that out either.
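An index-addressed maximum does exist in recent releases as `Tensor.index_reduce_`, assuming PyTorch 1.12 or newer (it was introduced there as a beta API). With `reduce="amax"`, `out.index_reduce_(0, index, source, "amax")` updates `out[index[i]]` to the max of itself and `source[i]` for every `i`, which is what the loop body `torch.max(data[i], max_at_idxs[segment_idx])` does one element at a time.

```python
import torch

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
num_segments = 2

# Running maxima, one per segment, initialised to -inf like the loop's sentinel
max_at_idxs = torch.full((num_segments,), float("-inf"), dtype=data.dtype)

# In place, for every i: max_at_idxs[segment_idxs[i]] =
#     max(max_at_idxs[segment_idxs[i]], data[i]); repeated indices are reduced together
max_at_idxs.index_reduce_(0, segment_idxs, data, "amax")
print(max_at_idxs)  # tensor([5., 7.])
```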

Any suggestions?