Similar to computing the sum of values with torch.index_add_(), how can I find the maximum value per index in a tensor, given a set of indices?
For example, computing sums can be done as:
```
x = torch.FloatTensor([3, 1, 4, 1, 5, 9, 2, 6])
idx = torch.LongTensor([0, 1, 1, 0, 0, 3, 2, 2])
sums = torch.zeros(4)
sums.index_add_(0, idx, x)
#  9
#  5
#  8
#  9
# [torch.FloatTensor of size 4]
```
I cannot seem to find an efficient way to do the equivalent operation for maxima. I.e.:
```
x = torch.FloatTensor([3, 1, 4, 1, 5, 9, 2, 6])
idx = torch.LongTensor([0, 1, 1, 0, 0, 3, 2, 2])
result = torch.zeros(4)
# WHAT OPERATION(S) WILL GIVE ME:
#  5
#  4
#  6
#  9
# [torch.FloatTensor of size 4]
```
Notice that the output is the maximum value given the associated index, rather than the sum.
How can I do this max operation efficiently?
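(Assuming a recent PyTorch is available — `Tensor.scatter_reduce_` was added in 1.12 — a minimal sketch of the grouped maximum on the example data; with `include_self=False` the zeros in the output tensor are ignored, so the result is the pure per-index maximum:)

```python
import torch

x = torch.tensor([3., 1., 4., 1., 5., 9., 2., 6.])
idx = torch.tensor([0, 1, 1, 0, 0, 3, 2, 2])

agg = torch.zeros(4)
# Scatter x into agg along dim 0, keeping the max per destination index.
# include_self=False means the initial zeros do not take part in the max.
agg.scatter_reduce_(0, idx, x, reduce="amax", include_self=False)
print(agg)  # tensor([5., 4., 6., 9.])
```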
Currently, I loop over the indices, selecting the data for each index and then computing torch.max over the selection. This is much, much less efficient.
Here’s a runnable script I’m using to compare the two; the looping approach obviously scales with the number of indices desired.
```
import time

import torch

N = 100000  # Num data points
I = 500     # Num indices

x = torch.rand(N)
idx = (torch.rand(N) * I).floor().long()
agg = torch.zeros(I)

# Built-in indexing operation
s = time.time()
agg.index_add_(0, idx, x)
t = time.time() - s
print('Builtin Indexing:', t)
# >>> 0.00680208206177 seconds on my laptop

# Looping for max
s = time.time()
for i in range(I):
    sel = x[idx == i]
    if sel.numel() > 0:
        agg[i] = sel.max()
t = time.time() - s
print('Looping:', t)
# >>> 2.53458189964 seconds on my laptop
```
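(For reference, a possible vectorized replacement for the loop, assuming PyTorch ≥ 1.12 where `Tensor.index_reduce_` is available; it uses the same shapes as the benchmark above and does the whole grouped maximum in one call:)

```python
import torch

N = 100000  # num data points
I = 500     # num indices

x = torch.rand(N)
idx = (torch.rand(N) * I).floor().long()

agg = torch.zeros(I)
# One vectorized call replaces the Python-level loop; include_self=False
# means the initial zeros do not participate in the maximum.
agg.index_reduce_(0, idx, x, reduce="amax", include_self=False)

# spot-check one slot against the loop-based result
assert agg[0] == x[idx == 0].max()
```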