How do I check whether a tensor is a floating-point tensor without checking for every specific dtype (float16, float32, double)? All I want is to check that the object is some kind of float, so that I can perform floating-point operations on it. I was expecting there to be some kind of abstract base class above all float types that I could check. Something like:
The best I can think of is some hackery using the string representation
of the tensor’s dtype. (Tensors of different dtypes are all instances of
the same class, namely torch.Tensor, so you can’t use the type of the
tensor, in the sense of the class of the tensor instance, to determine
its dtype.)
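A minimal sketch of that string-based check, assuming only that the names of the common floating-point dtypes (float16, bfloat16, float32, float64) all contain the substring "float". `is_float_tensor` is a hypothetical helper name, and complex dtypes are deliberately not matched:

```python
import torch

def is_float_tensor(t: torch.Tensor) -> bool:
    # Hypothetical helper: str(torch.float32) == "torch.float32" and
    # str(torch.double) == "torch.float64", so a substring test covers both.
    return "float" in str(t.dtype)

half = torch.zeros(2, dtype=torch.float16)
dbl = torch.zeros(2, dtype=torch.double)
ints = torch.zeros(2, dtype=torch.int64)

# All three share the same Python class, so type() cannot tell them apart.
print(type(half) is type(dbl) is type(ints) is torch.Tensor)  # True
print(is_float_tensor(half), is_float_tensor(dbl), is_float_tensor(ints))  # True True False
```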
Classes like torch.FloatTensor and torch.DoubleTensor aren’t really
used anymore*.
*) I suspect that in older versions of PyTorch, things like FloatTensor did
exist as bona fide classes, presumably subclasses of some sort of
torch.BaseTensor. I further suspect that, for reasons of backward
compatibility, when you “construct” a FloatTensor in current versions of
PyTorch (as I did in the above example), the operation gets delegated to
some sort of factory function such as torch.tensor() (note the lower-case
tensor).
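For illustration, a quick check of what the legacy constructor returns in a recent PyTorch version. This only shows the resulting class and dtype; it says nothing about how the construction is implemented internally:

```python
import torch

# Legacy-style constructor, still accepted in current PyTorch.
x = torch.FloatTensor([1.0, 2.0])

print(type(x))   # <class 'torch.Tensor'>
print(x.dtype)   # torch.float32
print(x.type())  # torch.FloatTensor  (a string, not a class)
```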
I ended up using @ptrblck’s suggestion, but your reply helped me understand the context, so thank you very much.
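For reference, PyTorch does expose built-in checks for this. A minimal sketch using `torch.is_floating_point` / `Tensor.is_floating_point()` and the `is_floating_point` attribute of `torch.dtype` (whether this is exactly what @ptrblck suggested is an assumption; note that complex dtypes return False here):

```python
import torch

t32 = torch.zeros(3)                       # default dtype: float32
t16 = torch.zeros(3, dtype=torch.float16)
ti = torch.arange(3)                       # int64

# Tensor-level checks, covering every floating-point dtype at once
print(torch.is_floating_point(t32), t16.is_floating_point(), ti.is_floating_point())  # True True False

# The same information is available on the dtype object itself
print(torch.float64.is_floating_point, torch.int32.is_floating_point)  # True False
```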
I guess all this also means that there is no good way to type-hint torch tensors as float tensors or int tensors, except maybe by manually defining a custom union of all the float types and keeping it somewhere in the code. I’d be surprised if that were the case; is there a particular reason why type hinting like that isn’t used in PyTorch?
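One partial workaround, sketched under assumptions: since torch.float32 and the other dtypes are instances of torch.dtype rather than classes, a union of them isn't something a static type checker can use as an annotation. A `typing.NewType` alias can at least document intent, with the dtype enforced at runtime. `FloatingTensor`, `as_floating_tensor`, and `scale` are hypothetical names, not PyTorch API:

```python
import torch
from typing import NewType

# Hypothetical alias: to a static type checker this is just a distinct name for
# torch.Tensor; the checker cannot verify anything about the dtype.
FloatingTensor = NewType("FloatingTensor", torch.Tensor)

def as_floating_tensor(t: torch.Tensor) -> FloatingTensor:
    # The dtype constraint can only be enforced at runtime.
    if not t.is_floating_point():
        raise TypeError(f"expected a floating-point tensor, got {t.dtype}")
    return FloatingTensor(t)

def scale(t: FloatingTensor, factor: float) -> FloatingTensor:
    # Annotated as taking the alias, so callers are nudged through the check above.
    return FloatingTensor(t * factor)

y = scale(as_floating_tensor(torch.ones(2)), 0.5)
print(y)  # tensor([0.5000, 0.5000])
```

The alias costs nothing at runtime (NewType just returns its argument), but the guarantee is only as strong as the discipline of routing tensors through the runtime check.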