# Removing Zeros (between non-zero values) and maintaining the Tensor dimensions

Hi Guys,

I would like to remove the zero values from a tensor in `[B, C, H, W]` format and “join” the non-zero values in each row, while keeping the tensor's dimensions. A naive way would be `out_x = x[x!=0]`, but this approach destroys the tensor's dimensions. In short, I would like to transform an input tensor like this:

```python
import torch

in_x = torch.tensor([[[[0.,   0.,   2.,  20., 250.,   0.,   0.,   0.,   0.,   0., 250.,  20.,   2.,   0.,   0.,   0.],
                       [0.,   0.,   2.,  20.,   0.,   0.,  20., 250.,   0.,   0., 250.,  20.,  20.,   2.,   0.,   0.],
                       [0.,   0.,   0.,   0.,   0.,   0.,   2.,  20., 250.,   0., 250.,  20.,  20.,   2.,   0.,   0.],
                       [0.,   2.,  20.,   0.,  20., 250.,   0., 250.,  20.,   0.,  20.,   2.,   0.,   0.,   0.,   0.]]]])
```

In an output tensor like this:

```python
out_x = torch.tensor([[[[  2.,  20., 250., 250.,  20.,   2.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
                        [  2.,  20.,  20., 250., 250.,  20.,  20.,   2.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
                        [  2.,  20., 250., 250.,  20.,  20.,   2.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
                        [  2.,  20.,  20., 250., 250.,  20.,  20.,   2.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.]]]])
```

Note that both tensors have the same shape: `(1, 1, 4, 16)`.
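To make the problem with the naive `x[x!=0]` approach concrete, here is a short sketch (reusing the input above; `out_naive` is just an illustrative name) showing that boolean-mask indexing returns a flat 1-D tensor. Because the rows hold different numbers of non-zero values (6, 8, 7 and 8), the result can't simply be reshaped back to `(1, 1, 4, 16)`.

```python
import torch

# Same (1, 1, 4, 16) input as above.
in_x = torch.tensor([[[[0., 0., 2., 20., 250., 0., 0., 0., 0., 0., 250., 20., 2., 0., 0., 0.],
                       [0., 0., 2., 20., 0., 0., 20., 250., 0., 0., 250., 20., 20., 2., 0., 0.],
                       [0., 0., 0., 0., 0., 0., 2., 20., 250., 0., 250., 20., 20., 2., 0., 0.],
                       [0., 2., 20., 0., 20., 250., 0., 250., 20., 0., 20., 2., 0., 0., 0., 0.]]]])

# Boolean-mask indexing keeps only the selected elements and flattens them
# into a single 1-D tensor, so the (B, C, H, W) structure is lost.
out_naive = in_x[in_x != 0]

print(in_x.shape)       # torch.Size([1, 1, 4, 16])
print(out_naive.shape)  # torch.Size([29])

# Non-zero count per row: differs from row to row, so no reshape can restore rows.
print((in_x != 0).sum(dim=3))  # tensor([[[6, 8, 7, 8]]])
```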

The solution must not contain `for` loops or anything else that degrades performance.

Thanks.

Hi Luiz!

At the cost of `n log(n)` time complexity, you can use `argsort()`* on a mask that marks the zero entries, and then use `gather()` to index back into your input tensor:

```python
>>> import torch
>>> torch.__version__
'1.9.0'
>>> in_x = torch.tensor([[[[0.,   0.,   2.,  20., 250.,   0.,   0.,   0.,   0.,   0., 250.,
...                         20.,   2.,   0.,   0.,   0.],
...                        [0.,   0.,   2.,  20.,   0.,   0.,  20., 250.,   0.,   0., 250.,
...                         20.,  20.,   2.,   0.,   0.],
...                        [0.,   0.,   0.,   0.,   0.,   0.,   2.,  20., 250.,   0., 250.,
...                         20.,  20.,   2.,   0.,   0.],
...                        [0.,   2.,  20.,   0.,  20., 250.,   0., 250.,  20.,   0.,  20.,
...                          2.,   0.,   0.,   0.,   0.]]]])
>>>
>>> in_x.gather (3, (in_x == 0.0).sort (dim = 3, stable = True)[1])
tensor([[[[  2.,  20., 250., 250.,  20.,   2.,   0.,   0.,   0.,   0.,   0.,
             0.,   0.,   0.,   0.,   0.],
          [  2.,  20.,  20., 250., 250.,  20.,  20.,   2.,   0.,   0.,   0.,
             0.,   0.,   0.,   0.,   0.],
          [  2.,  20., 250., 250.,  20.,  20.,   2.,   0.,   0.,   0.,   0.,
             0.,   0.,   0.,   0.,   0.],
          [  2.,  20.,  20., 250., 250.,  20.,  20.,   2.,   0.,   0.,   0.,
             0.,   0.,   0.,   0.,   0.]]]])
```
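The one-liner above can be unpacked into three steps. The sketch below wraps it in a helper, `compact_nonzero` (a hypothetical name, not from the original answer), that works along any dimension; for a 4-D tensor, `dim=3` and `dim=-1` are the same. The `stable=True` flag is what keeps the non-zero values in their original order: an unstable sort would only guarantee that they end up before the zeros.

```python
import torch

def compact_nonzero(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: move the non-zero entries of `x` to the front
    along `dim`, keeping their relative order and the shape of `x`."""
    # 1. Mask: False for non-zero entries, True for zeros.
    is_zero = x == 0
    # 2. A stable ascending sort of the mask puts every False before every True
    #    and keeps the original order inside each group. sort() returns
    #    (values, indices); only the indices are needed.
    _, order = is_zero.sort(dim=dim, stable=True)
    # 3. gather() reads x at those indices, giving the compacted tensor
    #    with exactly the same shape as x.
    return x.gather(dim, order)


x = torch.tensor([[0., 3., 0., 1.],
                  [5., 0., 0., 2.]])
print(compact_nonzero(x))
# tensor([[3., 1., 0., 0.],
#         [5., 2., 0., 0.]])
```

Calling `compact_nonzero(in_x, dim=3)` on the question's input is equivalent to the one-liner in the session above.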

*) Except that you have to use `sort()`, because (as of PyTorch 1.9) `argsort()` doesn’t support `stable = True`.
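The `n log(n)` cost comes from the sort. If that ever matters, one alternative is a different technique from the answer above: a prefix sum plus a scatter, which computes each non-zero element's destination index directly in linear time. This is a sketch under the assumption that "non-zero" means `x != 0`. `compact_nonzero_linear` is a hypothetical name. The idea is that for a non-zero element, (number of non-zeros up to and including it) − 1 is its position in the output.

```python
import torch

def compact_nonzero_linear(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: same result as the stable sort + gather trick,
    but the destination of each element comes from a cumulative sum."""
    nonzero = (x != 0).long()
    # Destination of each non-zero element: running count of non-zeros minus 1.
    # Zero elements get a clamped (valid) index; they only ever add 0 there.
    dest = (nonzero.cumsum(dim) - 1).clamp(min=0)
    # scatter_add_ (not scatter_) tolerates the duplicate indices produced by
    # the zero elements, because adding 0 leaves the target value unchanged.
    return torch.zeros_like(x).scatter_add_(dim, dest, x)


x = torch.tensor([[0., 3., 0., 1.],
                  [5., 0., 0., 2.]])
print(compact_nonzero_linear(x))
# tensor([[3., 1., 0., 0.],
#         [5., 2., 0., 0.]])
```

Whether this is actually faster than the sort-based one-liner depends on tensor size and device, so it is worth timing both on real data.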

Best.

K. Frank


Thank you very much, K. Frank, this works fine. Have a Merry Christmas and a Happy New Year!