Delete all zeroed rows from a tensor

Hello everyone!

I have a tensor, let’s say A, which contains zero rows, i.e. rows whose elements are all zero. There are other rows with non-zero elements as well. Is there any elegant way to remove all zeroed rows? To shed light on the question, consider the following example.

A=
[ [ [0,0,0,0,0], [0,-0.4,0,-0.9,0], [0,-0.7,0.9,0.2,0], [0,0,0,0,0] ] ]

What I want to get (let’s say tensor B) is the following:

B=
[ [ [0,-0.4,0,-0.9,0], [0,-0.7,0.9,0.2,0] ] ]

A = torch.Tensor([ [ [0,0,0,0,0], [0,-0.4,0,-0.9,0], [0,-0.7,0.9,0.2,0], [0,0,0,0,0] ] ])
B = A[A.sum(dim=2) != 0]
tensor([[ 0.0000, -0.4000, 0.0000, -0.9000, 0.0000],
[ 0.0000, -0.7000, 0.9000, 0.2000, 0.0000]])


Hi Sunny (and Sajad)!

That’s a nice solution.

Note, you should add an absolute-value in there to avoid deleting
a non-zero row that sums to zero, such as [1.0, -1.0, 0.0].
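For example, a minimal sketch on a recent PyTorch (the third row below sums to zero but is not a zero row, so it must survive):

```python
import torch

A = torch.tensor([[[0.0,  0.0, 0.0,  0.0, 0.0],
                   [0.0, -0.4, 0.0, -0.9, 0.0],
                   [1.0, -1.0, 0.0,  0.0, 0.0],   # non-zero row whose plain sum is zero
                   [0.0,  0.0, 0.0,  0.0, 0.0]]])

# without abs(), the [1.0, -1.0, ...] row would be (wrongly) deleted
B = A[A.abs().sum(dim=2) != 0]
print(B)
```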

Best.

K. Frank


Yeah, that’s right. I forgot about that.

Hi,

Thanks for the answer. Also, how can we get the list of indices corresponding to the “zero rows”?
For the above-mentioned example it would be [0, 3]. Thanks!!

edit: I actually want to remove the corresponding columns as well for the rows that were zero and removed. So if the rows at indices 0 and 3 of a 2D tensor are zero, I want to remove columns 0 and 3 as well.

Hi Suraj!

This question is a little simpler if the original tensor is square, that is,
it has the same number of rows as columns. (Not required, but then you
don’t have to fiddle around checking sizes.)

Here’s the most compact approach I could find, illustrated as a
pytorch version 0.3.0 script:

import torch
torch.__version__

# use a version of A that is square
# note, A has a leading size = 1 dimension

A = torch.FloatTensor ([ [ [0,0,0,0], [0,-0.4,0,-0.9], [0,-0.7,0.9,0.2], [0,0,0,0] ] ])
rows = (A.sum (dim = 2) != 0).squeeze()

# delete the rows and corresponding columns

C = A[torch.ger (rows, rows)].view ((1, rows.sum(), rows.sum()))
A
C

# get the non-zero-row indices

torch.arange (len (rows)).long()[rows]

# get the zero-row indices

torch.arange (len (rows)).long()[~rows]

Here is the output:

>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> # use a version of A that is square
... # note, A has a leading size = 1 dimension
...
>>> A = torch.FloatTensor ([ [ [0,0,0,0], [0,-0.4,0,-0.9], [0,-0.7,0.9,0.2], [0,0,0,0] ] ])
>>> rows = (A.sum (dim = 2) != 0).squeeze()
>>>
>>> # delete the rows and corresponding columns
...
>>> C = A[torch.ger (rows, rows)].view ((1, rows.sum(), rows.sum()))
>>> A

(0 ,.,.) =
  0.0000  0.0000  0.0000  0.0000
  0.0000 -0.4000  0.0000 -0.9000
  0.0000 -0.7000  0.9000  0.2000
  0.0000  0.0000  0.0000  0.0000
[torch.FloatTensor of size 1x4x4]

>>> C

(0 ,.,.) =
 -0.4000  0.0000
 -0.7000  0.9000
[torch.FloatTensor of size 1x2x2]

>>>
>>> # get the non-zero-row indices
...
>>> torch.arange (len (rows)).long()[rows]

 1
 2
[torch.LongTensor of size 2]

>>>
>>> # get the zero-row indices
...
>>> torch.arange (len (rows)).long()[~rows]

 0
 3
[torch.LongTensor of size 2]

I would do it by indexing into a range, as illustrated in the above script.
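On a recent PyTorch, the same idea can be sketched without ger(): the logical outer product rows.unsqueeze(1) & rows.unsqueeze(0) plays the same role and keeps the mask boolean, so indexing masks rather than gathers (variable names here are my own):

```python
import torch

A = torch.tensor([[[0.0,  0.0, 0.0,  0.0],
                   [0.0, -0.4, 0.0, -0.9],
                   [0.0, -0.7, 0.9,  0.2],
                   [0.0,  0.0, 0.0,  0.0]]])

rows = (A.abs().sum(dim=2) != 0).squeeze(0)   # bool mask of non-zero rows
mask = rows.unsqueeze(1) & rows.unsqueeze(0)  # outer "and": keep entry (i, j) iff row i AND row j survive
n = int(rows.sum())

C = A[0][mask].view(n, n).unsqueeze(0)        # zero rows and matching columns removed

nonzero_idx = torch.arange(len(rows))[rows]   # indices of the rows kept
zero_idx = torch.arange(len(rows))[~rows]     # indices of the rows removed
```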

Best.

K. Frank


Thank you @KFrank for taking some time out to answer. I am getting an error while using torch.ger().
Error message:

RuntimeError: _th_addr_out not supported on CPUType for Bool

Could it be because of the pytorch version? I have the latest one installed.

I actually had a naive (or a long) approach to it:

import torch
A = torch.FloatTensor( [ [0,0,0,0], [0,-0.4,0,-0.9], [0,-0.7,0.9,0.2], [0,0,0,0] ] )
idx = []
idx_removed = []
for i in range(4):
    flag = 0
    for j in range(4):
        if A[i,j] != 0:
            flag = 1
    if flag == 1:
        idx.append(i)
    else:
        idx_removed.append(i)

count = len(idx_removed)

A = A[idx,:]
A = A[:,idx]
A

count gives the number of zero rows, idx gives the rows and columns to keep, and idx_removed holds the indices to be removed. And if you just want the number of zero rows (this one needs import numpy as np):

count = np.count_nonzero(torch.sum(torch.abs(A), dim = -1) == 0)

Sorry for making it this long, as I have switched from C++ to Python not very long ago. Your approach is much more time-efficient, right? Thanks!
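For what it’s worth, the double loop above can be collapsed into a handful of tensor operations (a sketch; the variable names mirror the loop version, and no numpy is needed):

```python
import torch

A = torch.tensor([[0.0,  0.0, 0.0,  0.0],
                  [0.0, -0.4, 0.0, -0.9],
                  [0.0, -0.7, 0.9,  0.2],
                  [0.0,  0.0, 0.0,  0.0]])

keep = A.abs().sum(dim=1) != 0                # True where the row is non-zero
idx = torch.nonzero(keep).flatten()           # rows / columns to keep
idx_removed = torch.nonzero(~keep).flatten()  # zero-row indices
count = int((~keep).sum())                    # number of zero rows

A2 = A[idx][:, idx]                           # drop zero rows and matching columns
```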

Hi Suraj!

I’m guessing, but yes, it is probably the pytorch version. Unfortunately
I don’t have an up-to-date version installed to test with.

For me (version 0.3.0) rows, as returned by
rows = (A.sum (dim = 2) != 0).squeeze()
is a ByteTensor, while your error message mentions Bool.

From recollection, that is a post-0.3.0 change.

Anyway (guessing), try computing ger() on long tensors and converting
the result back to a boolean mask before indexing (indexing with the raw
LongTensor would gather rows instead of masking), e.g.:

n = int (rows.sum())
C = A[torch.ger (rows.long(), rows.long()).bool()].view ((1, n, n))

As a general (and usually reliable) rule, doing things solely with
pytorch tensor operations will be much more efficient than writing
a python loop that loops over elements (or rows / columns / slices)
of pytorch tensors.

If this happens inside of your training loop, for example, in your loss
function where you want to backpropagate through the results of
the loop computation, it can matter a lot. If you’re doing it on a
one-shot basis, for example, to prepare some data prior to training,
you might not care.
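To make the difference concrete, here is a rough comparison sketch (timings vary by machine, so only equality of the two results is checked):

```python
import time
import torch

A = torch.randn(2000, 1000)
A[::3] = 0.0                        # zero out every third row

# vectorized: a couple of tensor ops
t0 = time.perf_counter()
B = A[A.abs().sum(dim=1) != 0]
t_vec = time.perf_counter() - t0

# python loop over rows
t0 = time.perf_counter()
idx = [i for i in range(A.shape[0]) if float(A[i].abs().sum()) != 0.0]
C = A[idx]
t_loop = time.perf_counter() - t0

print(f"vectorized: {t_vec:.6f}s  loop: {t_loop:.6f}s")
```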

Best.

K. Frank
