How to check if a tensor is a float/double?

How do I check whether a tensor is a float object without checking for all the specific dtypes (float16, float32, double)? All I want is to check that the object is some kind of float so that I can perform floating-point operations on it. I was expecting there to be some kind of abstract base class above all float types that I could check. Something like:


But I haven’t found anything.

Hi Tue!

The best I can think of is some hackery using the string representation
of the tensor’s dtype. (Tensors of different dtypes are instances of
the same class, namely torch.Tensor, so you can’t use the type of
the tensor – in the sense of the class type of the tensor instance – to
determine the tensor’s dtype.)

Classes like torch.FloatTensor and torch.DoubleTensor aren’t really
used anymore*.


>>> import torch
>>> print (torch.__version__)
>>> tf = torch.FloatTensor ((5.6,))
>>> td = torch.DoubleTensor ((5.6,))
>>> tl = torch.LongTensor ((5.6,))
>>> type (tf)
<class 'torch.Tensor'>
>>> type (tf).mro()
[<class 'torch.Tensor'>, <class 'torch._C._TensorBase'>, <class 'object'>]
>>> tf.dtype
torch.float32
>>> 'float' in str (tf.dtype)
True
>>> type (td)
<class 'torch.Tensor'>
>>> type (td).mro()
[<class 'torch.Tensor'>, <class 'torch._C._TensorBase'>, <class 'object'>]
>>> td.dtype
torch.float64
>>> 'float' in str (td.dtype)
True
>>> type (tl)
<class 'torch.Tensor'>
>>> type (tl).mro()
[<class 'torch.Tensor'>, <class 'torch._C._TensorBase'>, <class 'object'>]
>>> tl.dtype
torch.int64
>>> 'float' in str (tl.dtype)
False
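As a sketch (not part of the original reply), the string check can be compared with the is_floating_point property that every torch.dtype object carries. The dtype list below is an arbitrary choice for illustration. For these dtypes the two checks should agree, including bfloat16 (whose name contains "float") and the complex types (whose names don't), but the attribute avoids depending on how a dtype happens to be printed.

```python
import torch

# Illustrative selection of dtypes: real floats, complex, integer and bool
dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64,
          torch.complex64, torch.complex128, torch.int32, torch.int64, torch.bool]

results = {}
for dtype in dtypes:
    by_string = 'float' in str(dtype)   # the string hack from the session above
    by_attr = dtype.is_floating_point   # property available on every torch.dtype
    results[dtype] = (by_string, by_attr)
    print(dtype, by_string, by_attr)
```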

*) I suspect that in older versions of pytorch things like FloatTensor did
exist as bona fide classes and presumably were subclasses of some sort
of torch.BaseTensor. I further suspect that for reasons of backward
compatibility when in current versions of pytorch you “construct” a
FloatTensor (as I did in the above example), the operation gets
delegated to some sort of factory function such as torch.tensor()
(note the lower-case tensor).
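A minimal sketch (an illustration, not a statement about how PyTorch implements the legacy constructors internally): for a simple input like the one above, the legacy-style constructors and torch.tensor() with an explicit dtype give plain torch.Tensor instances with the same dtype and values. The old class name survives as the string returned by Tensor.type().

```python
import torch

# Legacy-style constructors, as used in the session above
tf = torch.FloatTensor((5.6,))
td = torch.DoubleTensor((5.6,))

# Modern equivalents: the lower-case factory function with an explicit dtype
tf_new = torch.tensor((5.6,), dtype=torch.float32)
td_new = torch.tensor((5.6,), dtype=torch.float64)

# Both styles yield instances of the same class with the same dtype and values
print(type(tf) is type(tf_new), tf.dtype, torch.equal(tf, tf_new))
print(type(td) is type(td_new), td.dtype, torch.equal(td, td_new))

# The legacy name is still reported, but only as a string
print(tf_new.type(), td_new.type())
```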


K. Frank


Besides @KFrank’s checks, you could use torch.is_floating_point, as seen here:

import torch

for dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64, torch.complex32, torch.complex64, torch.complex128, torch.int32, torch.int64]:
    x = torch.randn(1).to(dtype)
    print("dtype {} is_floating_point {}".format(dtype, torch.is_floating_point(x)))

# dtype torch.float16 is_floating_point True
# dtype torch.bfloat16 is_floating_point True
# dtype torch.float32 is_floating_point True
# dtype torch.float64 is_floating_point True
# dtype torch.complex32 is_floating_point False
# dtype torch.complex64 is_floating_point False
# dtype torch.complex128 is_floating_point False
# dtype torch.int32 is_floating_point False
# dtype torch.int64 is_floating_point False
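As the output shows, torch.is_floating_point returns False for complex dtypes. If "some kind of float" should also cover complex tensors, one possible sketch combines it with Tensor.is_complex(). The helper name supports_float_math is hypothetical, chosen only for this illustration.

```python
import torch

def supports_float_math(t: torch.Tensor) -> bool:
    """Hypothetical helper: True for real floating-point or complex tensors."""
    # is_floating_point() is False for complex dtypes, so check is_complex() as well
    return t.is_floating_point() or t.is_complex()

for dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64,
              torch.complex64, torch.complex128, torch.int32, torch.int64]:
    x = torch.zeros(1, dtype=dtype)
    print(dtype, supports_float_math(x))
```

Whether complex tensors belong in the "float" bucket depends on which operations you need; if only real floats are acceptable, torch.is_floating_point alone is the simpler check.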

I ended up using @ptrblck’s suggestion, but your answer helped me understand the context, so thank you very much.

I guess all this also means that there is no good way to type-hint a torch tensor as holding floats or ints, except maybe by manually building a custom type union of all the float types and saving it somewhere in the code. I’d be surprised if that is the case; is there a particular reason why type hinting like that isn’t used in PyTorch?
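A sketch of one workaround, under stated assumptions: since every tensor is a torch.Tensor regardless of dtype, and torch.float32 and friends are torch.dtype instances rather than classes, they can't go into a typing.Union, and a static type checker can't tell a float tensor from an int tensor. One option is a documentation-only alias via typing.Annotated combined with a runtime check. The names FloatTensorHint, require_floating and scale are hypothetical, and this is not an official PyTorch convention.

```python
from typing import Annotated

import torch

# Hypothetical alias: for type checkers this is still just torch.Tensor;
# the "floating" marker is documentation only and is not enforced
FloatTensorHint = Annotated[torch.Tensor, "floating"]

def require_floating(t: torch.Tensor) -> FloatTensorHint:
    """Hypothetical runtime guard: reject tensors whose dtype is not floating point."""
    if not t.is_floating_point():
        raise TypeError(f"expected a floating-point tensor, got dtype {t.dtype}")
    return t

def scale(t: FloatTensorHint, factor: float) -> FloatTensorHint:
    # The annotation alone checks nothing, so validate at the call boundary
    return require_floating(t) * factor
```

Calling scale(torch.ones(2), 0.5) returns a float tensor, while passing an int64 tensor raises TypeError at runtime rather than being caught statically.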