Compute maximum value per index in a tensor given a set of indices

Similar to computing the sum of data using torch.index_add_(), how can I find the maximum value per index of data in a tensor, given a set of indices?

For example, computing sums can be done as:

import torch

x = torch.FloatTensor([3,1,4,1,5,9,2,6])
idx = torch.LongTensor([0,1,1,0,0,3,2,2])
sums = torch.zeros(4)  # named sums to avoid shadowing the built-in sum()
sums.index_add_(0, idx, x)

 9
 5
 8
 9
[torch.FloatTensor of size 4]

I cannot seem to find an efficient way to do the equivalent operation for maxima. I.e.:

x = torch.FloatTensor([3,1,4,1,5,9,2,6])
idx = torch.LongTensor([0,1,1,0,0,3,2,2])
max_vals = torch.zeros(4)
# WHAT OPERATION(S) WILL GIVE ME:

 5
 4
 6
 9
[torch.FloatTensor of size 4]

Notice that the output is the maximum value given the associated index, rather than the sum.

How can I do this max operation efficiently?

Currently, I loop through the indices, selecting the data for each index and then computing torch.max over the selected data. This is far slower than index_add_: about 2.5 seconds versus 0.007 seconds in the script below.

Here’s a runnable script that I’m using to compare. The running time of the loop scales with the number of indices.

import time
import torch

N = 100000 # Num data points
I = 500 # Num indices
x = torch.rand(N)
idx = (torch.rand(N) * I).floor().long()
agg = torch.zeros(I)

# Built-in indexing operation
s = time.time()
agg.index_add_(0, idx, x)
t = time.time() - s
print('Builtin Indexing:', t)
#>>> 0.00680208206177 seconds on my laptop

# Looping for max
agg = torch.zeros(I) # reset so the sums above do not leak into the maxima
s = time.time()
for i in range(I):
  sel = x[idx == i] # all data points assigned to index i
  if sel.shape[0] > 0:
    agg[i] = sel.max()

t = time.time() - s
print('Looping:', t)
#>>> 2.53458189964 seconds on my laptop

Did you ever find a solution to this? I have the same issue.
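
One possible approach, assuming a recent PyTorch release (1.12 or later, where Tensor.scatter_reduce_ is available), is a scatter-style reduction with reduce="amax". The sketch below has not been benchmarked against the script above, and index_max is a hypothetical helper name used only for illustration. It passes include_self=False so that the initial zeros do not take part in the maximum; indices that receive no data simply stay at 0, as in the loop version.

```python
import torch

def index_max(x, idx, num_bins):
    """Per-index maximum of x (hypothetical helper, not a PyTorch API).

    Bins that receive no value keep their initial value of 0.
    """
    out = x.new_zeros(num_bins)
    # include_self=False excludes the initial zeros from the reduction,
    # so all-negative data is still reduced correctly.
    out.scatter_reduce_(0, idx, x, reduce="amax", include_self=False)
    return out

x = torch.tensor([3., 1., 4., 1., 5., 9., 2., 6.])
idx = torch.tensor([0, 1, 1, 0, 0, 3, 2, 2])
result = index_max(x, idx, 4)
print(result)  # tensor([5., 4., 6., 9.])
```

Because this is a single vectorised call rather than a Python loop over indices, it should avoid the per-index overhead of the looping version, though the actual speedup depends on the device and PyTorch version.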