I posted this question on Stack Overflow ("PyTorch - how to find the max values at certain indices? (PyTorch equivalent of tf.math.unsorted_segment_max)") but have not gotten any suggestions, so I’ll post it again here.
In PyTorch, I need to efficiently find the max values at certain indices. TensorFlow has a built-in function, tf.math.unsorted_segment_max, that does this. I’ll borrow an example from their documentation (https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_max):
```python
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf

data = tf.constant([5, 1, 7, 2, 3, 4], dtype=tf.float32)
segment_idxs = tf.constant([0, 0, 1, 1, 0, 1], dtype=tf.int64)
num_segments = 2

max_at_idxs = tf.math.unsorted_segment_max(data, segment_idxs, num_segments)

print('\n' + 'max_at_idxs: ')
print(max_at_idxs)
```
```
max_at_idxs: 
tf.Tensor([5. 7.], shape=(2,), dtype=float32)
```
I realize that logic-wise I could use a loop like this:
```python
import torch

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
num_segments = 2

# Start from -inf so that any real value replaces the initial entry
max_at_idxs = torch.full(size=(num_segments, ), fill_value=float('-inf'), dtype=torch.float32)

# Visit every element and keep the running max for its segment
for i, segment_idx in enumerate(segment_idxs):
    max_at_idxs[segment_idx] = torch.max(data[i], max_at_idxs[segment_idx])
# end for

print('\n' + 'max_at_idxs: ')
print(max_at_idxs)
```
However, in my actual application segment_idxs is ~200,000 items long and this has to be done often, so a loop is far too slow.
Similar questions have been asked on the PyTorch forum before, notably:
How to perform segment max?
Is there a pytorch implementation of tensor flow segment softmax? - #4 by tom
Pytorch equivalent to `tf.unsorted_segment_sum`
but none of these resolves my concern, so I figured I’d ask here.
I keep figuring there has to be a way to do this with fancy indexing, possibly in combination with torch.where or some other built-in function, but I can’t seem to work it out.
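For what it's worth, one way the torch.where idea might be made to work is to build a boolean mask with one row per segment, fill the non-member positions with -inf, and take a row-wise max. This is only a rough sketch: the function name `segment_max_masked` is made up for illustration, it assumes a floating-point `data` tensor, and it materializes a (num_segments, N) intermediate, so it is probably only practical when num_segments is small.

```python
import torch

def segment_max_masked(data, segment_idxs, num_segments):
    # Hypothetical helper, not a PyTorch API.
    # Column vector of segment ids: shape (num_segments, 1)
    segment_ids = torch.arange(num_segments, device=data.device).unsqueeze(1)
    # mask[s, i] is True when element i belongs to segment s: shape (num_segments, N)
    mask = segment_idxs.unsqueeze(0) == segment_ids
    # Replace elements outside each segment with -inf so they never win the max
    neg_inf = torch.full_like(data, float('-inf'))
    masked = torch.where(mask, data, neg_inf)
    # Row-wise max gives one value per segment (-inf for an empty segment)
    return masked.max(dim=1).values

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
max_at_idxs = segment_max_masked(data, segment_idxs, num_segments=2)
print(max_at_idxs)
```

For the example data this should agree with the loop version; memory grows as num_segments times the length of segment_idxs.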
Also, I keep looking at the torch.maximum function, trying to figure out a way to use the 2nd parameter as indices rather than values, but I can’t find a way to make this work either.
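As far as I can tell, torch.maximum is strictly elementwise, so it has no notion of indices. A closer fit may be a scatter-style reduction. The sketch below assumes a PyTorch version that provides `Tensor.scatter_reduce` with `reduce="amax"` (I believe 1.12 or later); the function name `segment_max_scatter` is made up for illustration, and I have not benchmarked it at the ~200,000-element scale.

```python
import torch

def segment_max_scatter(data, segment_idxs, num_segments):
    # Hypothetical helper, not a PyTorch API.
    # Start every segment at -inf so any real value replaces it
    out = torch.full((num_segments,), float('-inf'),
                     dtype=data.dtype, device=data.device)
    # For each i, out[segment_idxs[i]] = max(out[segment_idxs[i]], data[i])
    return out.scatter_reduce(0, segment_idxs, data,
                              reduce='amax', include_self=True)

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
max_at_idxs = segment_max_scatter(data, segment_idxs, num_segments=2)
print(max_at_idxs)
```

Segments that receive no elements stay at -inf here, which may or may not match what tf.math.unsorted_segment_max returns for empty segments, so that case would need checking separately.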