In PyTorch 0.4.1, how can I check if a tensor contains boolean values?
Is the following code the standard way?
import torch
t1 = torch.tensor([True, False, False])
print(t1.dtype == torch.uint8)
Well, there is no tensor type that holds only boolean values. Such values are stored as uint8,
as you have seen, but a uint8 tensor could also contain numbers larger than 1.
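You can check the values as well as the dtype, with something like t1.dtype == torch.uint8 and bool(((t1 == 0) | (t1 == 1)).all()), which is True only if every element is 0 or 1. Note that & and | bind more tightly than comparisons in Python, so each comparison needs its own parentheses.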
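A minimal sketch of that combined check, wrapped in a helper. The name is_boolean_like is hypothetical (not a PyTorch API). As an assumption beyond the original thread, it also accepts the dedicated torch.bool dtype, which exists only in newer PyTorch releases (1.2 and later); on 0.4.1 the getattr lookup simply returns None and only the uint8 path is used.

```python
import torch

def is_boolean_like(t):
    """Hypothetical helper: True if t is a bool tensor, or a uint8 tensor whose
    elements are all exactly 0 or 1."""
    # torch.bool does not exist in PyTorch 0.4.1, so look it up defensively.
    bool_dtype = getattr(torch, "bool", None)
    if bool_dtype is not None and t.dtype == bool_dtype:
        return True  # a real bool tensor can only hold True/False
    if t.dtype != torch.uint8:
        return False  # float, int64, etc. are not treated as boolean masks
    # Parenthesize each comparison: | binds tighter than == in Python.
    return bool(((t == 0) | (t == 1)).all())

t1 = torch.tensor([True, False, False])
print(is_boolean_like(t1))
print(is_boolean_like(torch.tensor([0, 1, 2], dtype=torch.uint8)))
```

Checking the dtype first avoids scanning the data for tensors that cannot be masks; the element-wise test is what distinguishes a 0/1 mask from arbitrary byte data such as image pixels.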
Thanks for the confirmation!