I posted this question on Stack Overflow as "PyTorch - how to find the max values at certain indices ?? (PyTorch equivalent of tf.math.unsorted_segment_max)" but haven’t gotten any suggestions, so I’m posting it here as well.
In PyTorch, I need to efficiently find the max value for each index, i.e. the maximum of all elements of data that share the same segment index. TensorFlow has a built-in function for this, tf.math.unsorted_segment_max. To illustrate, I’ll borrow an example from its documentation (https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_max):
```python
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf

data = tf.constant([5, 1, 7, 2, 3, 4], dtype=tf.float32)
segment_idxs = tf.constant([0, 0, 1, 1, 0, 1], dtype=tf.int64)
num_segments = 2

max_at_idxs = tf.math.unsorted_segment_max(data, segment_idxs, num_segments)

print('\n' + 'max_at_idxs: ')
print(max_at_idxs)
```

Output:

```
max_at_idxs:
tf.Tensor([5. 7.], shape=(2,), dtype=float32)
```
I realize that, logic-wise, I could use a loop like this:

```python
import torch

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
num_segments = 2

# start every segment at the lowest float32 value (what TF returns for an empty segment)
max_at_idxs = torch.full(size=(num_segments, ), fill_value=torch.finfo(torch.float32).min, dtype=torch.float32)
for i, segment_idx in enumerate(segment_idxs):
    # keep the running max for the segment this element belongs to
    max_at_idxs[segment_idx] = torch.max(data[i], max_at_idxs[segment_idx])
# end for

print('\n' + 'max_at_idxs: ')
print(max_at_idxs)
```
However, in my actual application data and segment_idxs are about 200,000 elements long, and this has to be done often, so a Python loop is far too slow.
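For reference, a minimal sketch of a loop-free version of the loop above, assuming PyTorch 1.12 or newer, where `Tensor.scatter_reduce` accepts `reduce="amax"` (it was introduced as a beta feature). The variable names `lowest` and `init` are just for illustration.

```python
import torch

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
num_segments = 2

# Start every segment at the lowest representable float, mirroring what
# tf.math.unsorted_segment_max returns for segments that receive no elements.
lowest = torch.finfo(data.dtype).min
init = torch.full((num_segments,), lowest, dtype=data.dtype)

# scatter_reduce sends data[i] to position segment_idxs[i] and combines
# collisions with an element-wise max; include_self=True keeps `init` in the max.
max_at_idxs = init.scatter_reduce(0, segment_idxs, data, reduce="amax", include_self=True)
print(max_at_idxs)  # tensor([5., 7.])
```

Because the initial value takes part in the reduction, a segment with no elements keeps `lowest`, which matches the TensorFlow behavior for empty segments.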
Similar questions have been asked on the PyTorch forum before, notably:
- How to perform segment max?
- Is there a pytorch implementation of tensor flow segment softmax? - #4 by tom
- Pytorch equivalent to `tf.unsorted_segment_sum`

but none of them resolves my problem, so I figured I’d ask here.
I keep thinking there has to be a way to do this with fancy indexing, possibly combined with torch.where or some other built-in function, but I can’t seem to work it out.
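One way to express this with plain broadcasting and torch.where is sketched below. `segment_max_dense` is a hypothetical helper name, not a PyTorch API, and the sketch assumes the number of segments is small, since it materializes a `(num_segments, len(data))` boolean mask.

```python
import torch

def segment_max_dense(data, segment_idxs, num_segments):
    """Hypothetical helper: per-segment max via a dense mask and torch.where."""
    lowest = torch.finfo(data.dtype).min
    # mask[s, i] is True when element i belongs to segment s
    segment_ids = torch.arange(num_segments, device=segment_idxs.device)
    mask = segment_ids.unsqueeze(1) == segment_idxs.unsqueeze(0)
    # Outside each segment, substitute the lowest float so it never wins the max
    fill = torch.tensor(lowest, dtype=data.dtype, device=data.device)
    candidates = torch.where(mask, data.unsqueeze(0), fill)
    # Reduce each row (one row per segment) to its maximum
    return candidates.amax(dim=1)

data = torch.tensor([5, 1, 7, 2, 3, 4], dtype=torch.float32)
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
print(segment_max_dense(data, segment_idxs, 2))  # tensor([5., 7.])
```

The trade-off is memory: with ~200,000 elements and many segments the mask grows quickly, in which case a scatter-based reduction that never builds the mask is likely the better fit.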
Also, I keep looking at the torch.maximum function, trying to figure out a way to use its second argument as indices rather than values, but I can’t find a way to make that work either.
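torch.maximum compares two tensors element by element, so it has no notion of indices. An index-driven counterpart is `Tensor.index_reduce_` with `reduce="amax"`, assuming PyTorch 1.12 or newer, where it was introduced as a beta feature. It folds slice `i` of the source into `self[index[i]]` along a dimension, so it also covers a case where each segment element is a whole row. The 2-D `features` tensor below is made-up example data.

```python
import torch

# Made-up 2-D example: row i of `features` belongs to segment segment_idxs[i]
features = torch.tensor([[5., 0.], [1., 9.], [7., 2.],
                         [2., 3.], [3., 1.], [4., 8.]])
segment_idxs = torch.tensor([0, 0, 1, 1, 0, 1], dtype=torch.int64)
num_segments = 2

# One output row per segment, initialized to the lowest float (TF's empty-segment value)
lowest = torch.finfo(features.dtype).min
out = torch.full((num_segments, features.size(1)), lowest, dtype=features.dtype)

# Element-wise max of row i into out[segment_idxs[i]], i.e. a torch.maximum
# whose left-hand operand is chosen by index instead of by position
out.index_reduce_(0, segment_idxs, features, "amax")
print(out)  # expected values: [[5., 9.], [7., 8.]]
```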
Any suggestions?