PyTorch - how to find the max values at certain indices? (PyTorch equivalent of tf.math.unsorted_segment_max)

I posted this question on Stack Overflow ("PyTorch - how to find the max values at certain indices ?? (PyTorch equivalent of tf.math.unsorted_segment_max)") but have not gotten any suggestions, so I’ll post it again here.

In PyTorch, I need to efficiently find the max values at certain indices. TensorFlow has a built-in function, tf.math.unsorted_segment_max, that does this. I’ll borrow an example from its documentation (https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_max):

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf

data = tf.constant([5, 1, 7, 2, 3, 4], dtype=tf.float32)
segment_idxs = tf.constant([0, 0, 1, 1, 0, 1], dtype=tf.int64)
num_segments = 2

max_at_idxs = tf.math.unsorted_segment_max(data, segment_idxs, num_segments)

print('\n' + 'max_at_idxs: ')
print(max_at_idxs)

output:

max_at_idxs: 
tf.Tensor([5. 7.], shape=(2,), dtype=float32)

I realize that logic-wise I could use a loop like this:

import torch

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
num_segments = 2

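# Start each segment at -inf so that any real value replaces it
max_at_idxs = torch.full(size=(num_segments, ), fill_value=float('-inf'), dtype=torch.float32)
for i, segment_idx in enumerate(segment_idxs):
    # Keep the larger of the current element and the running max of its segment
    max_at_idxs[segment_idx] = torch.max(data[i], max_at_idxs[segment_idx])
# end for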

print('\n' + 'max_at_idxs: ')
print(max_at_idxs)

However, in my actual application, data and segment_idxs are each ~200,000 items long and this has to be done often, so a loop is far too slow.

Similar questions have been asked on the PyTorch forum before, notably:
How to perform segment max?
Is there a pytorch implementation of tensor flow segment softmax? - #4 by tom
Pytorch equivalent to `tf.unsorted_segment_sum`
but none of these resolves my concern, so I figured I’d ask here.

I keep figuring there has to be a way to do this with fancy indexing, possibly in combination with torch.where or some other built-in function, but I can’t seem to work it out.

Also, I keep looking at the torch.maximum function, trying to figure a way to use the 2nd parameter as indices rather than values, but I can’t find a way to work this out either.

Any suggestions?

DerekG posted a great response on Stack Overflow ("PyTorch - how to find the max values at certain indices ?? (PyTorch equivalent of tf.math.unsorted_segment_max)"). I’ll repeat the answer here, using my own variable names, in case the SO link ever goes bad:

import torch

xs = torch.tensor([8, 2, 6, 1, 3, 9, 4, 5, 7], dtype=torch.float32)
group_idxs = torch.tensor([0, 0, 1, 1, 2, 2, 0, 1, 2], dtype=torch.int64)
num_groups = 3
xs_len = len(xs)

# Slow way, un-comment if desired
# max_by_group = torch.full(size=(num_groups, ), fill_value=float('-inf'), dtype=torch.float32)
# for i, x in enumerate(xs):
#     max_by_group[group_idxs[i]] = max(x, max_by_group[group_idxs[i]])
# # end for

xs_exp = xs.unsqueeze(0).expand(num_groups, xs_len)

print('\n' + 'xs_exp: ')
print(xs_exp)

group_idxs_exp = group_idxs.unsqueeze(0).expand(num_groups, xs_len)

print('\n' + 'group_idxs_exp: ')
print(group_idxs_exp)

row_idxs = torch.arange(start=0, end=num_groups).unsqueeze(1).expand(num_groups, xs_len)

print('\n' + 'row_idxs: ')
print(row_idxs)

zero_or_neg_inf = torch.where(group_idxs_exp == row_idxs, 0, -torch.inf)

print('\n' + 'zero_or_neg_inf: ')
print(zero_or_neg_inf)

xs_exp_masked = xs_exp + zero_or_neg_inf

print('\n' + 'xs_exp_masked: ')
print(xs_exp_masked)

max_by_group, _ = torch.max(xs_exp_masked, dim=1)

print('\n' + 'max_by_group: ')
print(max_by_group)

print('\n')
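
One thing to keep in mind with the approach above is that it builds several tensors of shape (num_groups, xs_len), so memory use grows with the number of groups. As a possible alternative, here is a minimal sketch using Tensor.scatter_reduce_ with reduce='amax'. It assumes a PyTorch version that provides scatter_reduce_ (1.12 or later, as far as I know), a 1-D floating-point data tensor, and int64 indices. The segment_max helper name is my own, not a PyTorch built-in, and this is not part of DerekG's answer:

```python
import torch


def segment_max(data: torch.Tensor, segment_idxs: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Per-segment maximum of a 1-D float tensor (hypothetical helper, not a PyTorch built-in)."""
    # Start every segment at -inf so any real value replaces it.
    # Assumption: segments with no members are left at -inf.
    out = torch.full((num_segments,), float('-inf'), dtype=data.dtype, device=data.device)
    # For each i: out[segment_idxs[i]] = max(out[segment_idxs[i]], data[i])
    out.scatter_reduce_(0, segment_idxs, data, reduce='amax', include_self=True)
    return out


data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)

max_at_idxs = segment_max(data, segment_idxs, num_segments=2)
print(max_at_idxs)  # tensor([5., 7.])
```

With the xs and group_idxs from the answer above, segment_max(xs, group_idxs, 3) should give the same values as max_by_group. I have not benchmarked this against the broadcast version, so it is worth timing both on your real data.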