Hi Everyone,
I do not want to use torch.where(), but I still want to find the positive values in a tensor along a given dimension.
For example, with x = torch.randn(3, 288), I want to find the sum of the positive values along dimension 1. Note that x[x>0] will not help, because it creates a 1-D tensor.
Thanks in advance
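A quick illustration of why boolean indexing does not help here, assuming only that torch is installed. x[x > 0] gathers the selected elements from all rows into one flat tensor, because each row generally has a different number of positive entries and the result cannot stay rectangular:

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 288)

mask = x > 0              # boolean tensor with the same shape as x: (3, 288)
flat = x[mask]            # boolean indexing flattens: shape (number_of_positives,)

print(mask.shape)         # torch.Size([3, 288])
print(flat.dim())         # 1
print(mask.sum(dim=1))    # positive count per row; the counts usually differ
```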
KFrank
(K. Frank)
November 2, 2022, 3:12am
Hi Prakhar!
Prakhar_Pradhan1:
example, x = torch.randn(3,288), then I want to find the sum of positive values along the dimension 1.
In this particular instance, the two-argument form of max() (an elementwise maximum against zero) suffices.
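Because the two-argument torch.max is an elementwise maximum, zeroing the negatives can equally be written with torch.clamp or torch.relu. A small sketch comparing the three (the seed and shape are arbitrary choices for illustration):

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 288)

# Three ways to replace negative entries with 0 before summing along dim 1.
via_max = torch.max(x, torch.tensor([0.0])).sum(dim=1)  # elementwise max, broadcast against 0
via_clamp = torch.clamp(x, min=0).sum(dim=1)            # clamp from below at 0
via_relu = torch.relu(x).sum(dim=1)                     # relu(v) == max(v, 0)

print(via_max.shape)  # torch.Size([3]): one sum per row
```

torch.clamp(x, min=0) avoids allocating the extra zero tensor, which makes it a natural choice when the threshold is a plain number.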
More generally, you can still use a condition such as x > 0, but use it as a numerical mask rather than as an index.
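The mask idea can be packaged as a small helper that works for any elementwise condition. The name masked_sum is hypothetical, not a PyTorch function. This sketch uses Tensor.masked_fill instead of multiplication: both keep one result per row, but multiplying by a mask turns a -inf or NaN in an excluded position into NaN (0 * -inf and 0 * nan are both nan), whereas masked_fill simply overwrites those positions with 0.

```python
import torch

def masked_sum(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: sum the entries of x where mask is True, along dim.

    Entries where mask is False are overwritten with 0 (not multiplied by 0),
    so non-finite values in excluded positions do not leak into the result.
    """
    return x.masked_fill(~mask, 0).sum(dim=dim)

x = torch.tensor([[1.0, -2.0, 3.0],
                  [float("-inf"), 4.0, float("nan")]])
mask = x > 0  # any elementwise condition works, e.g. (x > 0.5) or (x.abs() < 1)

print(masked_sum(x, mask, dim=1))  # tensor([4., 4.])
print((mask * x).sum(dim=1))       # the second row becomes nan
```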
Consider:
>>> import torch
>>> torch.__version__
'1.12.0'
>>> _ = torch.manual_seed (2022)
>>>
>>> x = torch.randn (3, 288)
>>> torch.max (x, torch.tensor ([0.0])).sum (dim = 1)
tensor([118.9255, 116.4891, 128.0370])
>>>
>>> mask = x > 0.0 # boolean mask, but can be used numerically
>>> (mask * x).sum (dim = 1)
tensor([118.9255, 116.4891, 128.0370])
>>>
>>> (mask * x).sum (dim = 1) / mask.sum (dim = 1) # mean of positive elements
tensor([0.8202, 0.8381, 0.8053])
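One caveat about the mean on the last line: a row with no positive entries gives 0 / 0 = nan. A sketch that guards against this by clamping the count to at least 1, assuming 0 is an acceptable result for such rows (a design choice, not something stated in the thread):

```python
import torch

x = torch.tensor([[0.5, -1.0, 1.5],
                  [-2.0, -3.0, -0.5]])  # the second row has no positive values

mask = x > 0.0
total = (mask * x).sum(dim=1)           # sum of positive entries per row
count = mask.sum(dim=1)                 # number of positive entries per row (int64)

naive_mean = total / count              # second row: 0 / 0 gives nan
safe_mean = total / count.clamp(min=1)  # rows with no positives give 0 instead of nan

print(naive_mean)
print(safe_mean)
```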
Best.
K. Frank
KFrank:
(mask * x).sum (dim = 1)
Thank you so much, this will be a great help.