Can this for loop be avoided using PyTorch operations?

Hey! I implemented the following algorithm using Python for loops, which are, afaik, not very efficient:

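import torch

# x (the input tensor) and r are defined elsewhere
sorted, indices = torch.sort(torch.abs(torch.reshape(x, (-1,))), descending=True)
running_sum = -r  # r is some positive number
for idx, element in enumerate(sorted):
    running_sum += element
    if element <= running_sum / (idx + 1):
        running_sum -= element  # undo the last addition
        idx -= 1                # idx now points at the last entry that failed the test
        break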

Is there any way to do this faster using PyTorch operations? The crucial part is finding the first entry of a large sorted tensor that is less than or equal to the mean of all previous entries (with r subtracted from their sum).
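The loop's test compares each entry with a mean that includes that entry itself. A short derivation (added here, not from the post) shows that this matches the "mean of all previous entries" description, and that, because the entries are sorted in descending order, once the test holds it keeps holding for every later entry. The notation e_k, S_k, m_k is introduced only for this derivation.

```latex
% Notation (introduced here): e_1 \ge e_2 \ge \dots are the sorted entries,
% S_k = e_1 + \dots + e_k, and m_k = (S_k - r)/k. The loop tests e_k \le m_k at idx = k - 1.
e_k \le m_k
  \iff k\,e_k \le S_{k-1} + e_k - r
  \iff (k-1)\,e_k \le S_{k-1} - r
  \iff e_k \le m_{k-1} \qquad (k \ge 2)

% For k = 1 the test reads e_1 \le e_1 - r, which never holds because r > 0.

% Once the test holds at k, it also holds at k + 1:
e_k \le m_{k-1} \implies
m_k = \frac{(k-1)\,m_{k-1} + e_k}{k} \ge e_k \ge e_{k+1}
```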

Looking forward to your ideas! Thanks a lot!

Hi Max!

Yes, you can do this with pytorch tensor operations. Use cumsum()
divided by something like arange() to get the running mean. Then
test that against your sorted tensor to find when your condition first
holds.
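As a small hand-checkable illustration of that idea, take four already-sorted toy values (chosen for this example) and r = 2. This is written for a recent PyTorch rather than 0.3.0, and `sorted` is renamed `sorted_vals` so that Python's built-in is not shadowed:

```python
import torch

# Toy values, already sorted in descending order (chosen for illustration).
sorted_vals = torch.tensor([4.0, 3.0, 2.0, 1.0])
r = 2.0

# Running mean of the first k entries, with r subtracted from the running sum.
counts = torch.arange(1, sorted_vals.numel() + 1, dtype=sorted_vals.dtype)
running_mean = (torch.cumsum(sorted_vals, 0) - r) / counts
# running_mean holds [2.0, 2.5, 2.3333, 2.0]

# Elementwise test: the first True marks the first entry meeting the condition.
is_le = sorted_vals <= running_mean
print(is_le.tolist())  # [False, False, True, True]
```

The first True is at position 2, so in the loop's convention (which steps back one entry before breaking) idx would be 1.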

Here is a pytorch version 0.3.0 script:

import torch
torch.__version__

torch.manual_seed (2020)

x = torch.randn ((3, 2))
r = 2.0

# pytorch tensor operations
sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)               # sort data
running_mean = (torch.cumsum (sorted, 0) - r) / torch.arange (1, sorted.numel() + 1)   # get cumulative mean
is_le = sorted <= running_mean          # test condition
idx = is_le.numel() - is_le.sum() - 1   # get index
print ('idx =', idx)
print ('sorted[idx + 1] =', sorted[idx + 1])
print ('running_mean[idx + 1] =', running_mean[idx + 1])


# loop
# sorted, indices = torch.sort(torch.abs(torch.reshape(x, (-1,))), descending=True)
sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)
running_sum = -r # r is some positive number
for idx, element in enumerate(sorted):
    running_sum += element
    if element <= (running_sum) / (idx + 1):
         running_sum -= element
         idx -= 1
         break

print ('x =', x)
print ('sorted =', sorted)
print ('running_sum =', running_sum)
print ('idx =', idx)
print ('sorted[idx + 1] =', sorted[idx + 1])
print ('(running_sum + sorted[idx + 1]) / (idx + 2) =', (running_sum + sorted[idx + 1]) / (idx + 2))

And here is its output:

>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x0000020EC7916630>
>>>
>>> x = torch.randn ((3, 2))
>>> r = 2.0
>>>
>>> # pytorch tensor operations
... sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)               # sort data
>>> running_mean = (torch.cumsum (sorted, 0) - r) / torch.arange (1, sorted.numel() + 1)   # get cumulative mean
>>> is_le = sorted <= running_mean          # test condition
>>> idx = is_le.numel() - is_le.sum() - 1   # get index
>>> print ('idx =', idx)
idx = 3
>>> print ('sorted[idx + 1] =', sorted[idx + 1])
sorted[idx + 1] = 0.40786537528038025
>>> print ('running_mean[idx + 1] =', running_mean[idx + 1])
running_mean[idx + 1] = 0.6055201292037964
>>>
>>>
>>> # loop
... # sorted, indices = torch.sort(torch.abs(torch.reshape(x, (-1,))), descending=True)
... sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)
>>> running_sum = -r # r is some positive number
>>> for idx, element in enumerate(sorted):
...     running_sum += element
...     if element <= (running_sum) / (idx + 1):
...          running_sum -= element
...          idx -= 1
...          break
...
>>> print ('x =', x)
x =
 1.2372 -0.9604
 1.5415 -0.4079
 0.8806  0.0529
[torch.FloatTensor of size 3x2]

>>> print ('sorted =', sorted)
sorted =
 1.5415
 1.2372
 0.9604
 0.8806
 0.4079
 0.0529
[torch.FloatTensor of size 6]

>>> print ('running_sum =', running_sum)
running_sum = 2.619735360145569
>>> print ('idx =', idx)
idx = 3
>>> print ('sorted[idx + 1] =', sorted[idx + 1])
sorted[idx + 1] = 0.40786537528038025
>>> print ('(running_sum + sorted[idx + 1]) / (idx + 2) =', (running_sum + sorted[idx + 1]) / (idx + 2))
(running_sum + sorted[idx + 1]) / (idx + 2) = 0.6055201470851899
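The script above targets 0.3.0. In current PyTorch releases, `torch.arange(1, n + 1)` with integer arguments returns an int64 tensor, so a version for a recent PyTorch is clearer if it passes `dtype` explicitly. A minimal sketch of the same computation packaged as a function, next to a loop reference, might look like this. `first_le_index` and `first_le_index_loop` are hypothetical names, and both return the position of the first entry that satisfies the test (that is, `idx + 1` in the thread's notation), or the length of the tensor if no entry does:

```python
import torch


def first_le_index(x: torch.Tensor, r: float) -> int:
    """Vectorized: index (in descending-sorted |x|) of the first entry that is
    <= the running mean with r subtracted from the running sum; numel() if none."""
    sorted_vals, _ = torch.sort(torch.abs(x.reshape(-1)), descending=True)
    counts = torch.arange(1, sorted_vals.numel() + 1,
                          dtype=sorted_vals.dtype, device=sorted_vals.device)
    running_mean = (torch.cumsum(sorted_vals, 0) - r) / counts
    is_le = sorted_vals <= running_mean
    # The test pattern is False...False True...True, so the number of
    # False entries equals the index of the first True.
    return int((~is_le).sum())


def first_le_index_loop(x: torch.Tensor, r: float) -> int:
    """Reference: the same test as an early-exit Python loop."""
    sorted_vals, _ = torch.sort(torch.abs(x.reshape(-1)), descending=True)
    running_sum = -r
    for k, element in enumerate(sorted_vals.tolist()):
        running_sum += element
        if element <= running_sum / (k + 1):
            return k
    return len(sorted_vals)


# Values printed in the thread (rounded to 4 decimals).
x = torch.tensor([[1.2372, -0.9604], [1.5415, -0.4079], [0.8806, 0.0529]])
print(first_le_index(x, 2.0), first_le_index_loop(x, 2.0))  # 4 4
```

The result 4 corresponds to idx = 3 in the output above.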

Best.

K. Frank


Thanks for the fast and intuitive answer! One question remains: I was somewhat surprised by the way you find idx. Is that really the most efficient way to find the index of the last False element of is_le? numel() is probably constant-time regardless of the size of is_le, but isn't sum() linear?

Hi Max!

First, yes, sum() is linear. (As an aside, finding the first True
index is also linear in general; see below.)

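The key point is that sum() is implemented as a pytorch tensor
operation, so it runs in pytorch's optimized compiled (or GPU) code
rather than one element at a time in the Python interpreter, and that
is a potentially important benefit.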

Note that many libraries offer a find_first-type function for searching
strings or arrays. I’m not aware that pytorch has one (at least not
in my old version). One could consider using pytorch’s argmax()
function, but in this case several elements of is_le take on the
maximum value, and I don’t think that pytorch guarantees that
argmax() will return the index of the first maximum element.
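For reference, recent PyTorch documentation for `torch.argmax()` states that when several elements share the maximal value, the index of the first one is returned; that is worth confirming for whichever version you run. Under that assumption, a find-first helper could look like the sketch below. `find_first_true` is a hypothetical name, not a PyTorch function, and the mask is cast to an integer type before calling argmax():

```python
import torch


def find_first_true(mask: torch.Tensor) -> int:
    """Hypothetical helper: index of the first True in a 1-D bool tensor,
    or mask.numel() if there is none."""
    # argmax() would return 0 for an all-False mask, so handle that case first.
    if not bool(mask.any()):
        return mask.numel()
    # Assumes argmax() returns the first of several tied maxima (see lead-in).
    return int(torch.argmax(mask.to(torch.int32)))


is_le = torch.tensor([False, False, True, True])
print(find_first_true(is_le))                # 2
print(int(is_le.numel() - is_le.sum()))      # 2, the count-based expression
```

Unlike the count-based `numel() - sum()` expression, this does not rely on the mask being all False followed by all True.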

Now (other than the initial sort) this entire algorithm is linear. But
this algorithm looks at the entire sorted vector – whether or not
it needs to. You only have to calculate your running mean until your
condition is satisfied. So if your condition is equally likely to be
satisfied anywhere in the vector, then, on average, you would only
have to look at half of sorted and compute only half of the running
mean. Cheaper, but still linear.
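The early-stopping idea above can be combined with tensor operations by processing sorted in fixed-size chunks: each chunk gets the same cumsum()/arange() test, the running sum is carried into the next chunk, and the scan stops at the first chunk in which the test holds. This is a sketch not taken from the thread; `first_le_index_chunked` and its `chunk` parameter are hypothetical names, and the default of 4096 is an arbitrary assumption. It relies on the fact that, with sorted in descending order, once the test holds for an entry it holds for every later entry, so within the chunk that contains the first hit the count of False entries locates it:

```python
import torch


def first_le_index_chunked(sorted_vals: torch.Tensor, r: float, chunk: int = 4096) -> int:
    """Hypothetical helper: index of the first entry of a descending-sorted 1-D
    tensor that is <= the running mean (r subtracted from the running sum),
    or numel() if none. Stops after the first chunk containing a hit."""
    n = sorted_vals.numel()
    # Running sum of all entries before the current chunk, minus r.
    carry = torch.tensor(-float(r), dtype=sorted_vals.dtype, device=sorted_vals.device)
    for start in range(0, n, chunk):
        block = sorted_vals[start:start + chunk]
        sums = torch.cumsum(block, 0) + carry
        counts = torch.arange(start + 1, start + block.numel() + 1,
                              dtype=sorted_vals.dtype, device=sorted_vals.device)
        is_le = block <= sums / counts
        if bool(is_le.any()):
            # Pattern inside the chunk is False...False True...True.
            return start + int((~is_le).sum())
        carry = sums[-1]
    return n
```

Smaller chunks can stop sooner but pay more Python overhead per chunk; with `chunk` at least `sorted_vals.numel()` it reduces to the single-pass version from the thread. Whether it beats one full pass depends on where the condition is typically met.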

If your condition were strongly enough biased to occur towards the
beginning of your vector, the algorithm could even become order 1
(“O(1)”) in the average case. But, given that pytorch doesn’t have
anything built into it that does this for you, you’re almost certainly
better off using the linear, but optimized, pytorch tensor operations
than writing your own non-tensor loop (even in the hypothetical
average-case O(1) example).

The same comments apply to finding the first True index in is_le.
This, in general, is linear, but could be cheaper in operations if you
run a loop that stops “early” when it finds that index. But the tensor
operation sum() is likely to be cheaper in time (but not operations),
even in the hypothetical average-case O(1) situation.
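A rough way to check this trade-off on your own machine is to time an early-exit Python loop against the count-based tensor expression. In the sketch below, the helper names and the mask (whose first True sits 10% of the way in) are hypothetical; no timings are shown because they depend on hardware and PyTorch version:

```python
import time

import torch


def first_true_loop(mask: torch.Tensor) -> int:
    # Early exit: stops at the first True, but each step runs in the interpreter.
    for i in range(mask.numel()):
        if bool(mask[i]):
            return i
    return mask.numel()


def first_true_count(mask: torch.Tensor) -> int:
    # Tensor operation from the thread: visits every element, in optimized code.
    # Assumes the False...False True...True pattern.
    return int(mask.numel() - mask.sum())


n = 200_000
mask = torch.arange(n) >= n // 10  # hypothetical mask, first True at index n // 10

for fn in (first_true_loop, first_true_count):
    t0 = time.perf_counter()
    result = fn(mask)
    elapsed = time.perf_counter() - t0
    print(f"{fn.__name__}: index={result}, {elapsed * 1e3:.2f} ms")
```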

Best.

K. Frank
