Removing Zeros (between non-zero values) and maintaining the Tensor dimensions

Hi Guys,

I would like to remove zero values of a tensor and “join” the non-zero values in each row of a tensor in format [B, C, H, W]. A naive way would to do out_x = x[x!=0], this approach is bad because would destruct the Tensor dimensions. For resume, I would like to transform an input tensor like this:

in_x = torch.tensor([[[[0.,   0.,   2.,  20., 250.,   0.,   0.,   0.,   0.,   0., 250.,
                      20.,   2.,   0.,   0.,   0.],
                      [0.,   0.,   2.,  20.,   0.,   0.,  20., 250.,   0.,   0., 250.,
                      20.,  20.,   2.,   0.,   0.],
                      [0.,   0.,   0.,   0.,   0.,   0.,   2.,  20., 250.,   0., 250.,
                      20.,  20.,   2.,   0.,   0.],
                      [0.,   2.,  20.,   0.,  20., 250.,   0., 250.,  20.,   0.,  20.,
                       2.,   0.,   0.,   0.,   0.]]]])

In an output tensor like this:

out_x = torch.tensor([[[[  2.,  20., 250., 250.,  20.,   2.,   0.,   0.,   0.,   0.,   0.,
                          0.,   0.,   0.,   0.,   0.],
                       [  2.,  20.,  20., 250., 250., 20.,  20.,   2.,    0.,   0.,   0.,
                          0.,   0.,   0.,   0.,   0.],
                       [  2.,  20., 250., 250.,  20., 20.,   2.,   0.,   0.,   0.,   0.,
                          0.,   0.,   0.,   0.,   0.],
                       [  2.,  20.,  20., 250., 250.,  20.,  20.,   2., 0.,   0.,   0.,
                          0.,   0.,   0.,   0.,   0.]]]])

Note that both Tensors have the same shape: (1,1,4,16)

The solution must not contain For Loops or anything else that degrades performance.

Thanks.

Hi Luiz!

At the cost of n log (n) time complexity, you can use argsort()* and
then use gather() to index back into your input tensor:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> in_x = torch.tensor([[[[0.,   0.,   2.,  20., 250.,   0.,   0.,   0.,   0.,   0., 250.,
...                       20.,   2.,   0.,   0.,   0.],
...                       [0.,   0.,   2.,  20.,   0.,   0.,  20., 250.,   0.,   0., 250.,
...                       20.,  20.,   2.,   0.,   0.],
...                       [0.,   0.,   0.,   0.,   0.,   0.,   2.,  20., 250.,   0., 250.,
...                       20.,  20.,   2.,   0.,   0.],
...                       [0.,   2.,  20.,   0.,  20., 250.,   0., 250.,  20.,   0.,  20.,
...                        2.,   0.,   0.,   0.,   0.]]]])
>>>
>>> in_x.gather (3, (in_x == 0.0).sort (dim = 3, stable = True)[1])
tensor([[[[  2.,  20., 250., 250.,  20.,   2.,   0.,   0.,   0.,   0.,   0.,
             0.,   0.,   0.,   0.,   0.],
          [  2.,  20.,  20., 250., 250.,  20.,  20.,   2.,   0.,   0.,   0.,
             0.,   0.,   0.,   0.,   0.],
          [  2.,  20., 250., 250.,  20.,  20.,   2.,   0.,   0.,   0.,   0.,
             0.,   0.,   0.,   0.,   0.],
          [  2.,  20.,  20., 250., 250.,  20.,  20.,   2.,   0.,   0.,   0.,
             0.,   0.,   0.,   0.,   0.]]]])

*) Except you have to use sort() because argsort() doesn’t support
stable = True).

Best.

K. Frank

1 Like

Thank you very much K. Frank, this works fine. Have a Merry Christmas and a Happy New Year!