# Can this for loop be avoided using Pytorch operations?

Hey! I implemented the following algorithm using python for loops, which are, afaik, not very efficient:

```python
sorted, indices = torch.sort(torch.abs(torch.reshape(x, (-1,))), descending=True)
running_sum = -r  # r is some positive number
for idx, element in enumerate(sorted):
    running_sum += element
    if element <= (running_sum) / (idx + 1):
        running_sum -= element
        idx -= 1
        break
```

Is there any way to do this faster using Pytorch operations? The crucial part is finding the first entry of a large sorted tensor such that said entry is less than or equal to the mean of all previous entries.

Looking forward to your ideas! Thanks a lot!

Hi Max!

Yes, you can do this with pytorch tensor operations. Use `cumsum()`
divided by something like `arange()` to get the running mean. Then
test that against your sorted tensor to find when your condition first
holds.

Here is a pytorch version 0.3.0 script:

```python
import torch
torch.__version__

torch.manual_seed (2020)

x = torch.randn ((3, 2))
r = 2.0

# pytorch tensor operations
sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)               # sort data
running_mean = (torch.cumsum (sorted, 0) - r) / torch.arange (1, sorted.numel() + 1)   # get cumulative mean
is_le = sorted <= running_mean          # test condition
idx = is_le.numel() - is_le.sum() - 1   # get index
print ('idx =', idx)
print ('sorted[idx + 1] =', sorted[idx + 1])
print ('running_mean[idx + 1] =', running_mean[idx + 1])

# loop
# sorted, indices = torch.sort(torch.abs(torch.reshape(x, (-1,))), descending=True)
sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)
running_sum = -r  # r is some positive number
for idx, element in enumerate(sorted):
    running_sum += element
    if element <= (running_sum) / (idx + 1):
        running_sum -= element
        idx -= 1
        break

print ('x =', x)
print ('sorted =', sorted)
print ('running_sum =', running_sum)
print ('idx =', idx)
print ('sorted[idx + 1] =', sorted[idx + 1])
print ('(running_sum + sorted[idx + 1]) / (idx + 2) =', (running_sum + sorted[idx + 1]) / (idx + 2))
```

And here is its output:

```
>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x0000020EC7916630>
>>>
>>> x = torch.randn ((3, 2))
>>> r = 2.0
>>>
>>> # pytorch tensor operations
... sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)               # sort data
>>> running_mean = (torch.cumsum (sorted, 0) - r) / torch.arange (1, sorted.numel() + 1)   # get cumulative mean
>>> is_le = sorted <= running_mean          # test condition
>>> idx = is_le.numel() - is_le.sum() - 1   # get index
>>> print ('idx =', idx)
idx = 3
>>> print ('sorted[idx + 1] =', sorted[idx + 1])
sorted[idx + 1] = 0.40786537528038025
>>> print ('running_mean[idx + 1] =', running_mean[idx + 1])
running_mean[idx + 1] = 0.6055201292037964
>>>
>>>
>>> # loop
... # sorted, indices = torch.sort(torch.abs(torch.reshape(x, (-1,))), descending=True)
... sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)
>>> running_sum = -r # r is some positive number
>>> for idx, element in enumerate(sorted):
...     running_sum += element
...     if element <= (running_sum) / (idx + 1):
...          running_sum -= element
...          idx -= 1
...          break
...
>>> print ('x =', x)
x =
1.2372 -0.9604
1.5415 -0.4079
0.8806  0.0529
[torch.FloatTensor of size 3x2]

>>> print ('sorted =', sorted)
sorted =
1.5415
1.2372
0.9604
0.8806
0.4079
0.0529
[torch.FloatTensor of size 6]

>>> print ('running_sum =', running_sum)
running_sum = 2.619735360145569
>>> print ('idx =', idx)
idx = 3
>>> print ('sorted[idx + 1] =', sorted[idx + 1])
sorted[idx + 1] = 0.40786537528038025
>>> print ('(running_sum + sorted[idx + 1]) / (idx + 2) =', (running_sum + sorted[idx + 1]) / (idx + 2))
(running_sum + sorted[idx + 1]) / (idx + 2) = 0.6055201470851899
```

Best.

K. Frank


Thanks for the fast and intuitive answer! One question remains: I was kind of surprised by your way to find `idx`. Is that really the most efficient way of finding the index of the last `False` element of `is_le`? `numel()` is probably constant in the size of `is_le`, but `sum()` should be linear?

Hi Max!

First, yes, `sum()` is linear. (As an aside, finding the first `True`
index is also linear; see below.)

The key point is that `sum()` is implemented as a pytorch tensor
operation, so you get that potentially important benefit.

Note, many libraries offer a `find_first` type of function for searching
strings or arrays. I'm not aware that pytorch has such (at least not
in my old version). One could consider using pytorch's `argmax()`
function, but in this case several elements of `is_le` take on the
maximum value, and I don't think that pytorch guarantees that
`argmax()` will return the index of the first maximum element.
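For completeness, one way to get the first `True` index without relying on `argmax()` tie-breaking is `nonzero()`, which lists the indices of all `True` elements in ascending order, so its first row is the first match. Here is a sketch, assuming a pytorch more recent than 0.3.0 and using hard-coded values similar to those in the example above:

```python
import torch

# values similar to the sorted example above, already sorted descending
sorted_vals = torch.tensor([1.5, 1.2, 0.96, 0.88, 0.41, 0.05])
r = 2.0

# cumulative mean (with the -r offset), as in the script above
running_mean = (torch.cumsum(sorted_vals, 0) - r) / torch.arange(1, sorted_vals.numel() + 1)
is_le = sorted_vals <= running_mean

# nonzero() returns the indices of True elements in ascending order,
# so the first row is the first index where the condition holds
# (note: this indexing fails if the condition never holds)
first_true = torch.nonzero(is_le)[0].item()
idx = first_true - 1  # last index before the condition holds

# agrees with the sum()-based count of False elements
assert idx == is_le.numel() - int(is_le.sum()) - 1
```

Like `sum()`, this is still a linear tensor operation over the whole vector; it just makes the "first index" intent explicit.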

Now (other than the initial sort) this entire algorithm is linear. But
this algorithm looks at the entire `sorted` vector, whether or not
it needs to. You only have to calculate your running mean until your
condition is satisfied. So if your condition is equally likely to be
satisfied anywhere in the vector, then, on average, you would only
have to look at half of `sorted` and compute only half of the running
mean. Cheaper, but still linear.
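If that early stopping matters for your sizes, one middle ground (my sketch, not anything built into pytorch) is to scan `sorted` in chunks with tensor operations, stopping at the first chunk in which the condition appears. You still pay for the sort, but on average you only touch the prefix of the vector up to the first hit:

```python
import torch

def find_idx_chunked(sorted_vals, r, chunk=1024):
    """Scan sorted_vals (descending) in chunks; return the index of the
    last element before sorted_vals[i] <= (cumsum - r) / (i + 1) first holds."""
    total = 0.0  # running sum of all elements before the current chunk
    n = sorted_vals.numel()
    for start in range(0, n, chunk):
        block = sorted_vals[start:start + chunk]
        csum = torch.cumsum(block, 0) + total
        mean = (csum - r) / torch.arange(start + 1, start + block.numel() + 1)
        hits = torch.nonzero(block <= mean)
        if hits.numel() > 0:
            return start + hits[0].item() - 1  # last index before condition holds
        total = csum[-1].item()
    return n - 1  # condition never held

sorted_vals = torch.tensor([1.5, 1.2, 0.96, 0.88, 0.41, 0.05])
idx = find_idx_chunked(sorted_vals, r=2.0, chunk=2)
```

With a chunk size in the hundreds or thousands, you keep most of the tensor-operation benefit while skipping the tail of the vector when the condition hits early.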

If your condition were strongly-enough biased to occur towards the
beginning of your vector, the algorithm could even become order 1
("O(1)"), for the average case. But, given that pytorch doesn't have
anything built into it that does this for you, you're almost certainly
better off using the linear, but optimized, pytorch tensor operations
than writing your own non-tensor loop (even in the hypothetical
average-case O(1) example).

The same comments apply to finding the first `True` index in `is_le`.
This, in general, is linear, but could be cheaper in operations if you
run a loop that stops "early" when it finds that index. But the tensor
operation `sum()` is likely to be cheaper in time (but not operations),
even in the hypothetical average-case O(1) situation.

Best.

K. Frank
