Using type(torch.uint8) is extremely dangerous. Let’s say you have a tensor of floats in the range -0.1 to 0.1 (really, any values strictly between -1 and 1). Casting to uint8 truncates the fractional part, so every value becomes 0, which any() will read as False (“there are no non-zero values”).
Here’s an example of the bug your code causes. The tensor looks completely empty (all zeros) because the cast to uint8 turned every value into 0:
>>> import torch
>>> foo = torch.tensor([-0.1, 0.05, 0.3])
>>> foo.type(torch.uint8)
tensor([0, 0, 0], dtype=torch.uint8)
>>> foo.type(torch.uint8).any()
tensor(0, dtype=torch.uint8)
>>> bool(foo.type(torch.uint8).any())
False
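The same session as a small script, comparing the uint8 cast against calling any() directly on the float tensor. This is a sketch; the variable names other than foo are just for illustration:

```python
import torch

foo = torch.tensor([-0.1, 0.05, 0.3])

# Float -> uint8 conversion drops the fractional part, so every element here
# (all have magnitude below 1) becomes 0 before any() even runs.
casted = foo.type(torch.uint8)
cast_says_nonzero = bool(casted.any())

# Calling any() on the original float tensor checks the real values instead.
native_says_nonzero = bool(foo.any())

print(casted)               # tensor([0, 0, 0], dtype=torch.uint8)
print(cast_says_nonzero)    # False
print(native_says_nonzero)  # True
```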
By the way, I think the data type conversion totally defeats the purpose of any(), which is to quickly check for any non-zero values, because it wastes time converting every value first. In that case it would be more efficient to just use your_tensor.count_nonzero() > 0, which still has to count every non-zero value, but at least it should be faster and use less memory than converting everything!
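A minimal sketch of the count_nonzero() > 0 check next to a plain any() call on the native dtype. The helper names has_nonzero_count and has_nonzero_any and the sample tensors are hypothetical, chosen only to show that both checks agree without any dtype conversion:

```python
import torch

def has_nonzero_count(t: torch.Tensor) -> bool:
    # count_nonzero() tallies the non-zero elements in t's own dtype;
    # comparing that count with 0 answers "is any value non-zero?"
    return bool(t.count_nonzero() > 0)

def has_nonzero_any(t: torch.Tensor) -> bool:
    # any() on the native (float) dtype gives the same answer directly.
    return bool(t.any())

samples = [
    torch.tensor([-0.1, 0.05, 0.3]),  # small non-zero values
    torch.tensor([0.0, 0.0, 0.0]),    # all zeros
    torch.tensor([0.0, 0.0, 2.5]),    # a single non-zero value
]

for s in samples:
    print(has_nonzero_count(s), has_nonzero_any(s))
# True True
# False False
# True True
```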
Anyway, it’s 2024 and I’m using PyTorch 2.3.1+cu118.
It handles .any() and .all() on float tensors without any issues. NOTHING needs to be converted. Keep the native types!
>>> import torch
>>> x = torch.tensor([0.0, -0.5, 1.2, 50.3])
>>> y = torch.tensor([0.0, 0.0, 0.0, 0.0])
>>> z = torch.tensor([999.0, -0.5, 1.2, 50.3])
>>> x.any()
tensor(True)
>>> y.any()
tensor(False)
>>> z.any()
tensor(True)
>>> x.all()
tensor(False)
>>> y.all()
tensor(False)
>>> z.all()
tensor(True)
https://pytorch.org/docs/stable/generated/torch.any.html#torch.any
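The torch.any documentation linked above also explains the odd tensor(0, dtype=torch.uint8) output in the first example: any() returns a bool tensor for every supported dtype except uint8, where the result stays uint8. A quick check of that, assuming a recent PyTorch such as the 2.3.1 build mentioned here (the tensor names are illustrative):

```python
import torch

floats = torch.tensor([0.0, -0.5, 1.2, 50.3])
small_ints = torch.tensor([0, 3, 0], dtype=torch.uint8)

# Float input: any() returns a bool tensor.
print(floats.any().dtype)      # torch.bool

# uint8 input: the result keeps the uint8 dtype, per the torch.any docs.
print(small_ints.any().dtype)  # torch.uint8
```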
The code is available here for anyone curious what any() does on different backends: