Hi Max!

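Yes, you can do this with PyTorch tensor operations. Use `cumsum()` (with `r` subtracted) divided by something like `arange()` to get the running mean. Then test that against your sorted tensor to find where your condition first holds. Because the tensor is sorted in descending order, once the condition holds it keeps holding. So counting the entries where it fails tells you where it first holds.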

Here is a script for PyTorch version 0.3.0:

```
import torch
torch.__version__  # bare expression: only displays the version in an interactive session
torch.manual_seed (2020)
x = torch.randn ((3, 2))
r = 2.0  # r is some positive number
# pytorch tensor operations
# NOTE: `sorted` shadows the Python builtin sorted() for the rest of this script.
# Sort |x| (flattened to 1-D) in descending order.
sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True) # sort data
# running_mean[k] = (sorted[0] + ... + sorted[k] - r) / (k + 1)
running_mean = (torch.cumsum (sorted, 0) - r) / torch.arange (1, sorted.numel() + 1) # get cumulative mean
is_le = sorted <= running_mean # test condition
# Because `sorted` is descending, once sorted[k] <= running_mean[k] holds it
# holds for every later k: running_mean[k + 1] is a weighted average of
# running_mean[k] and sorted[k + 1], and running_mean[k] >= sorted[k] >= sorted[k + 1].
# So is_le is all False followed by all True, and counting the False entries
# gives idx = the last index where the condition fails (idx + 1 is the first
# index where it holds).
# NOTE(review): if the condition never holds, idx == numel() - 1 and the
# [idx + 1] lookups below are out of range.
idx = is_le.numel() - is_le.sum() - 1 # get index
print ('idx =', idx)
print ('sorted[idx + 1] =', sorted[idx + 1])
print ('running_mean[idx + 1] =', running_mean[idx + 1])
# loop
# The same computation with an explicit Python loop, for comparison.
# sorted, indices = torch.sort(torch.abs(torch.reshape(x, (-1,))), descending=True)
sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)
running_sum = -r # r is some positive number
for idx, element in enumerate(sorted):
    running_sum += element
    # Same test as is_le[idx] above: element <= running_mean[idx].
    if element <= (running_sum) / (idx + 1):
        # Undo the last step so running_sum and idx describe the last element
        # for which the condition failed (matching the tensor version's idx).
        running_sum -= element
        idx -= 1
        break
# NOTE(review): if the loop never breaks, idx ends at numel() - 1 and
# running_sum includes every element, so sorted[idx + 1] below is out of range.
print ('x =', x)
print ('sorted =', sorted)
print ('running_sum =', running_sum)
print ('idx =', idx)
print ('sorted[idx + 1] =', sorted[idx + 1])
# Recomputes running_mean[idx + 1] from the loop's state.
print ('(running_sum + sorted[idx + 1]) / (idx + 2) =', (running_sum + sorted[idx + 1]) / (idx + 2))
```

And here is its output:

```
>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x0000020EC7916630>
>>>
>>> x = torch.randn ((3, 2))
>>> r = 2.0
>>>
>>> # pytorch tensor operations
... sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True) # sort data
>>> running_mean = (torch.cumsum (sorted, 0) - r) / torch.arange (1, sorted.numel() + 1) # get cumulative mean
>>> is_le = sorted <= running_mean # test condition
>>> idx = is_le.numel() - is_le.sum() - 1 # get index
>>> print ('idx =', idx)
idx = 3
>>> print ('sorted[idx + 1] =', sorted[idx + 1])
sorted[idx + 1] = 0.40786537528038025
>>> print ('running_mean[idx + 1] =', running_mean[idx + 1])
running_mean[idx + 1] = 0.6055201292037964
>>>
>>>
>>> # loop
... # sorted, indices = torch.sort(torch.abs(torch.reshape(x, (-1,))), descending=True)
... sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)
>>> running_sum = -r # r is some positive number
>>> for idx, element in enumerate(sorted):
... running_sum += element
... if element <= (running_sum) / (idx + 1):
... running_sum -= element
... idx -= 1
... break
...
>>> print ('x =', x)
x =
1.2372 -0.9604
1.5415 -0.4079
0.8806 0.0529
[torch.FloatTensor of size 3x2]
>>> print ('sorted =', sorted)
sorted =
1.5415
1.2372
0.9604
0.8806
0.4079
0.0529
[torch.FloatTensor of size 6]
>>> print ('running_sum =', running_sum)
running_sum = 2.619735360145569
>>> print ('idx =', idx)
idx = 3
>>> print ('sorted[idx + 1] =', sorted[idx + 1])
sorted[idx + 1] = 0.40786537528038025
>>> print ('(running_sum + sorted[idx + 1]) / (idx + 2) =', (running_sum + sorted[idx + 1]) / (idx + 2))
(running_sum + sorted[idx + 1]) / (idx + 2) = 0.6055201470851899
```

Best.

K. Frank