Efficient computation of median ignoring zeros (only torch operations)


Given a 2D tensor, I’d like to compute the median of each row as efficiently as possible. The caveat is that I’d like to ignore any zeros when computing the median.

For example, if my tensor is

[ [3, 1, 2], 
  [0, 0, 1],
  [0, 0, 0] ]

then computing the median of the nonzero values in each row should give the output [2, 1, 0], where a row containing only zeros gives 0.
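Stated precisely, the target for row i can be written as below. This assumes that rows with an even number of nonzeros should follow torch::median's convention of returning the lower of the two middle values, which is what the loop further down computes. Here k_i is the number of nonzero entries in row i and s_{i,1} ≤ … ≤ s_{i,k_i} are those entries in ascending order:

```latex
m_i =
\begin{cases}
  s_{i,\lceil k_i/2 \rceil} & \text{if } k_i > 0,\\
  0 & \text{if } k_i = 0.
\end{cases}
```

For the example: row 1 has nonzeros (1, 2, 3), so k_1 = 3 and m_1 = s_{1,2} = 2; row 2 has k_2 = 1 and m_2 = 1; row 3 has k_3 = 0 and m_3 = 0. In 0-based terms the selected position is ⌊(k_i − 1)/2⌋.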

My current approach is as follows:

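auto outputTensor = torch::zeros({tensor.size(0)}, tensor.options());  // one median per row; all-zero rows stay 0
for (int64_t i = 0; i < tensor.size(0); ++i) {  // go through each row in tensor
    auto noZeroInds = torch::nonzero(tensor[i]).view(-1);  // indices of the nonzero entries in row i
    if (noZeroInds.numel() == 0) continue;  // median of an empty selection is undefined, so keep 0
    outputTensor[i] = torch::median(tensor[i].index_select(0, noZeroInds));
}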

As you can see, I loop over the rows of the tensor, select the indices of the nonzero entries, compute the median of the selected values, and store it in an output tensor.

So my question is: is there a more efficient, vectorized way to do this using only torch operations?