Check if at least one element in tensor is nonzero

I want to check if a tensor contains nonzero elements. It does not matter to me how many nonzero elements there are. I currently do the following, which is very inefficient for large tensors:

T.abs().sum().item() == 0

Is there a better way to do this in PyTorch? Like a function that stops iterating through the tensor as soon as it finds the first nonzero element?

Not too sure on that, but I do believe that this:

T.abs().sum().item() == 0

will be much faster than

"a function that stops iterating through the tensor as soon as it finds the first nonzero element"

as loops/iterators in Python are very, very slow.
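To put rough numbers on that, here is a quick timing sketch (the tensor size and repeat counts are arbitrary, just for illustration):

import timeit
import torch

T = torch.zeros(10_000)  # all-zero tensor: worst case for any "early exit" check

# Vectorized reduction: the whole pass runs in C/CUDA, outside Python.
vectorized = timeit.timeit(lambda: T.abs().sum().item() == 0, number=10)

# Python-level loop with early exit: every element crosses the Python boundary.
def loop_any_nonzero(t):
    for v in t:
        if v != 0:
            return True
    return False

looped = timeit.timeit(lambda: loop_any_nonzero(T), number=10)
print(f"vectorized: {vectorized:.4f}s  python loop: {looped:.4f}s")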

Also this might be of interest: .nonzero()


You can call some_tensor.nonzero() and check for an empty result (i.e. tensor([])).
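For example (a minimal sketch; the tensor contents here are just made up):

import torch

t = torch.zeros(3, 4)
t[1, 2] = 5.0

# nonzero() returns one row of indices per non-zero element,
# so an empty result means the tensor is all zeros.
print(t.nonzero())                           # tensor([[1, 2]])
print(t.nonzero().numel() > 0)               # True  -> at least one non-zero element
print(torch.zeros(3).nonzero().numel() > 0)  # False -> all zeros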


tensorName.type(torch.uint8).any()

I guess tensor_name.any() should work fine. Why the type() call?

Without typecasting, I get this error:
RuntimeError: all only supports torch.uint8 and torch.bool dtypes
At least with PyTorch version 1.7.1.
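On those older versions, a workaround that avoids the lossy uint8 cast is to compare against zero first, since != already produces a bool tensor that any() accepts (a sketch, not verified against every release):

import torch

t = torch.tensor([-0.1, 0.05, 0.3])

# (t != 0) is a bool tensor, so .any() is accepted even on versions that
# reject .any() on float dtypes, and no values are truncated along the way.
has_nonzero = (t != 0).any()
print(bool(has_nonzero))  # True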

Using type(torch.uint8) is extremely dangerous.

Let’s say you have a tensor of floats in the range -0.1 to 0.1.

Casting to uint8 will convert every value to 0, which any() will then see as False ("there are no non-zero values").

Here’s an example of the bug your code causes: it treats the tensor as all zeros because the cast to uint8 truncated every value:

>>> foo = torch.tensor([-0.1, 0.05, 0.3])
>>> foo.type(torch.uint8)
tensor([0, 0, 0], dtype=torch.uint8)
>>> foo.type(torch.uint8).any()
tensor(0, dtype=torch.uint8)
>>> bool(foo.type(torch.uint8).any())
False

By the way, the data type conversion also defeats the purpose of any() as a quick check for non-zero values, because it wastes time converting every value first. In that case it would be more efficient to just use your_tensor.count_nonzero() > 0, which still counts every non-zero element, but at least it is faster and uses less memory than converting the whole tensor first!
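Something like this (a small sketch; count_nonzero is available in recent PyTorch versions):

import torch

t = torch.tensor([-0.1, 0.05, 0.3])

# Counts every non-zero element, but without any cast, so small float
# values such as -0.1 are not silently truncated to 0.
print(t.count_nonzero() > 0)               # tensor(True)
print(torch.zeros(4).count_nonzero() > 0)  # tensor(False)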

Anyway, it’s 2024 and I use Torch “2.3.1+cu118”.

It is perfectly capable of running .any() on float tensors without any issues. Nothing needs to be converted. Keep the native types!

>>> import torch
>>> x = torch.tensor([0.0, -0.5, 1.2, 50.3])
>>> y = torch.tensor([0.0, 0.0, 0.0, 0.0])
>>> z = torch.tensor([999.0, -0.5, 1.2, 50.3])
>>> x.any()
tensor(True)
>>> y.any()
tensor(False)
>>> z.any()
tensor(True)
>>> x.all()
tensor(False)
>>> y.all()
tensor(False)
>>> z.all()
tensor(True)

https://pytorch.org/docs/stable/generated/torch.any.html#torch.any

The code is available here for anyone curious what any() does for different backends: