Hi Max!
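Yes, you can do this with PyTorch tensor operations. Use cumsum()
(minus r) divided by something like arange() to get the running mean.
Then test that against your sorted tensor to find where your condition
first holds. Because the values are sorted in descending order, once the
condition holds it keeps holding, so counting the elements for which it
fails gives you the index directly.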
Here is a script written for PyTorch version 0.3.0:
import torch


def split_index_vectorized(values, r):
    """Locate the split point using tensor operations only.

    ``values`` is expected to be 1-D and sorted in descending order (the
    script below passes the output of ``torch.sort(..., descending=True)``).

    Returns ``(idx, running_mean)`` where
    ``running_mean[k] == (values[0] + ... + values[k] - r) / (k + 1)`` and
    ``idx`` is the last position at which ``values[k] <= running_mean[k]``
    does NOT hold:

    * ``idx == -1``: the condition holds from the first element on
      (or ``values`` is empty);
    * ``idx == values.numel() - 1``: the condition never holds, so there is
      no element at ``idx + 1``.
    """
    running_mean = (torch.cumsum(values, 0) - r) / torch.arange(1, values.numel() + 1)
    is_le = values <= running_mean
    # For descending values, once the condition holds it keeps holding, so
    # the number of failures pinpoints the first success.
    idx = int(is_le.numel() - is_le.sum() - 1)
    return idx, running_mean


def split_index_loop(values, r):
    """Locate the same split point with an explicit Python loop.

    Returns ``(idx, running_sum)``. For descending-sorted ``values``,
    ``idx`` equals the index from ``split_index_vectorized``, and
    ``running_sum == values[0] + ... + values[idx] - r``.
    """
    running_sum = -r  # r is some positive number
    idx = -1  # keeps idx defined when values is empty
    for idx, element in enumerate(values):
        running_sum += element
        if element <= running_sum / (idx + 1):
            # Undo the step that satisfied the condition so that idx and
            # running_sum describe the last element where it failed.
            running_sum -= element
            idx -= 1
            break
    return idx, running_sum


# A bare `torch.__version__` only echoes in the REPL; print it in a script.
print('torch.__version__ =', torch.__version__)
torch.manual_seed(2020)
x = torch.randn((3, 2))
r = 2.0

# Sort |x| once, in descending order (named `sorted_vals` so the built-in
# `sorted` is not shadowed).
sorted_vals, indices = torch.sort(torch.abs(x.view(-1)), descending=True)
n = sorted_vals.numel()

# pytorch tensor operations
idx, running_mean = split_index_vectorized(sorted_vals, r)
print('idx =', idx)
if idx + 1 < n:
    print('sorted[idx + 1] =', sorted_vals[idx + 1])
    print('running_mean[idx + 1] =', running_mean[idx + 1])
else:
    print('condition never holds')

# loop
idx, running_sum = split_index_loop(sorted_vals, r)
print('x =', x)
print('sorted =', sorted_vals)
print('running_sum =', running_sum)
print('idx =', idx)
if idx + 1 < n:
    print('sorted[idx + 1] =', sorted_vals[idx + 1])
    print('(running_sum + sorted[idx + 1]) / (idx + 2) =',
          (running_sum + sorted_vals[idx + 1]) / (idx + 2))
else:
    print('condition never holds')
And here is its output:
>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x0000020EC7916630>
>>>
>>> x = torch.randn ((3, 2))
>>> r = 2.0
>>>
>>> # pytorch tensor operations
... sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True) # sort data
>>> running_mean = (torch.cumsum (sorted, 0) - r) / torch.arange (1, sorted.numel() + 1) # get cumulative mean
>>> is_le = sorted <= running_mean # test condition
>>> idx = is_le.numel() - is_le.sum() - 1 # get index
>>> print ('idx =', idx)
idx = 3
>>> print ('sorted[idx + 1] =', sorted[idx + 1])
sorted[idx + 1] = 0.40786537528038025
>>> print ('running_mean[idx + 1] =', running_mean[idx + 1])
running_mean[idx + 1] = 0.6055201292037964
>>>
>>>
>>> # loop
... # sorted, indices = torch.sort(torch.abs(torch.reshape(x, (-1,))), descending=True)
... sorted, indices = torch.sort(torch.abs( x.view (-1,) ), descending=True)
>>> running_sum = -r # r is some positive number
>>> for idx, element in enumerate(sorted):
... running_sum += element
... if element <= (running_sum) / (idx + 1):
... running_sum -= element
... idx -= 1
... break
...
>>> print ('x =', x)
x =
1.2372 -0.9604
1.5415 -0.4079
0.8806 0.0529
[torch.FloatTensor of size 3x2]
>>> print ('sorted =', sorted)
sorted =
1.5415
1.2372
0.9604
0.8806
0.4079
0.0529
[torch.FloatTensor of size 6]
>>> print ('running_sum =', running_sum)
running_sum = 2.619735360145569
>>> print ('idx =', idx)
idx = 3
>>> print ('sorted[idx + 1] =', sorted[idx + 1])
sorted[idx + 1] = 0.40786537528038025
>>> print ('(running_sum + sorted[idx + 1]) / (idx + 2) =', (running_sum + sorted[idx + 1]) / (idx + 2))
(running_sum + sorted[idx + 1]) / (idx + 2) = 0.6055201470851899
Best.
K. Frank