Similar to computing the sum of data using torch.index_add_(), how can I find the maximum value per index of data in a tensor, given a set of indices?
For example, computing sums can be done as:
import torch

x = torch.FloatTensor([3, 1, 4, 1, 5, 9, 2, 6])
idx = torch.LongTensor([0, 1, 1, 0, 0, 3, 2, 2])
sums = torch.zeros(4)  # renamed so it doesn't shadow the builtin sum()
sums.index_add_(0, idx, x)
9
5
8
9
[torch.FloatTensor of size 4]
I cannot seem to find an efficient way to do the equivalent operation for maxima. I.e.:
x = torch.FloatTensor([3, 1, 4, 1, 5, 9, 2, 6])
idx = torch.LongTensor([0, 1, 1, 0, 0, 3, 2, 2])
maxes = torch.zeros(4)
# WHAT OPERATION(S) WILL GIVE ME:
5
4
6
9
[torch.FloatTensor of size 4]
Notice that the output is the maximum value given the associated index, rather than the sum.
How can I do this max operation efficiently?
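For reference, one vectorized possibility, assuming a reasonably recent PyTorch (Tensor.scatter_reduce_ was added in 1.12), is to scatter-reduce with "amax". A minimal sketch on the example data above:

```python
import torch

x = torch.tensor([3., 1., 4., 1., 5., 9., 2., 6.])
idx = torch.tensor([0, 1, 1, 0, 0, 3, 2, 2])

maxes = torch.zeros(4)
# Each maxes[idx[j]] becomes the max of all x[j] sharing that index;
# include_self=False keeps the initial zeros out of the reduction.
maxes.scatter_reduce_(0, idx, x, reduce="amax", include_self=False)
print(maxes)  # tensor([5., 4., 6., 9.])
```

Buckets that receive no source elements are left at their initial value, so initialize the output accordingly.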
Currently, I loop over the indices, selecting the data for each index with a boolean mask and then computing torch.max over the selection. This is much, much less efficient. Here's a runnable script that I'm using to compare; obviously the looping approach scales with the number of indices.
import time
import torch

N = 100000  # Num data points
I = 500     # Num indices
x = torch.rand(N)
idx = (torch.rand(N) * I).floor().long()
agg = torch.zeros(I)
# Built-in indexing operation
s = time.time()
agg.index_add_(0, idx, x)
t = time.time() - s
print('Builtin Indexing:', t)
#>>> 0.00680208206177 seconds on my laptop
# Looping for max
s = time.time()
for i in range(I):
    sel = x[idx == i]
    if sel.shape[0] > 0:
        agg[i] = torch.max(sel)
t = time.time() - s
print('Looping:', t)
#>>> 2.53458189964 seconds on my laptop
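If a recent PyTorch is available (an assumption; Tensor.index_reduce_ was added in 1.12), the whole loop can likely be replaced with a single call that mirrors index_add_'s calling convention:

```python
import torch

N, I = 100000, 500  # same setup as the benchmark above
x = torch.rand(N)
idx = (torch.rand(N) * I).floor().long()

agg = torch.zeros(I)
# Same call shape as index_add_, but reducing with "amax";
# include_self=False keeps the initial zeros out of the reduction.
agg.index_reduce_(0, idx, x, reduce="amax", include_self=False)
```

Rows of agg that no index points to are left untouched, which matches the looping version's behavior of skipping empty selections.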