When using type annotations in Python, what base type should we use for an argument that can either be a list, tuple, torch.Tensor, numpy.ndarray, etc.?
I would have expected the following to return True, since torch.Tensor implements both __len__ and __getitem__:
>>> from collections.abc import Sequence
>>> import torch
>>> a = torch.tensor((3, 4, 5.0))
>>> isinstance(a, Sequence)
False
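For what it's worth, the check returns False because collections.abc.Sequence, unlike one-method ABCs such as Sized or Iterable, has no __subclasshook__: a class only counts as a Sequence if it inherits from it or is explicitly registered. One possible way to annotate "anything with __len__ and __getitem__" is a structural typing.Protocol. The sketch below is only an illustration: SizedIndexable, ToyTensor and first_item are hypothetical names, and ToyTensor is a stand-in for torch.Tensor so the snippet runs without torch installed.

```python
from collections.abc import Sequence
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SizedIndexable(Protocol):
    """Hypothetical structural type: anything with __len__ and __getitem__."""

    def __len__(self) -> int: ...
    def __getitem__(self, index: Any) -> Any: ...


class ToyTensor:
    """Stand-in for torch.Tensor (an assumption, so the sketch runs without torch)."""

    def __init__(self, values):
        self._values = list(values)

    def __len__(self):
        return len(self._values)

    def __getitem__(self, index):
        return self._values[index]


def first_item(xs: SizedIndexable) -> Any:
    """Accepts lists, tuples, and tensor-like objects alike."""
    if len(xs) == 0:
        raise ValueError("empty input")
    return xs[0]


t = ToyTensor((3, 4, 5.0))
print(isinstance(t, Sequence))        # nominal check: not a subclass, not registered
print(isinstance(t, SizedIndexable))  # structural check: only looks for the methods
print(first_item(t), first_item([1, 2]), first_item((7,)))
```

Note that a runtime_checkable protocol only checks that the methods exist, not their signatures or whether calling them succeeds on a given instance.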
Great question. I'd like to know whether I can rely on a tensor never being a Sequence, because I want to distinguish a single tensor from a list (or rather a sequence) of tensors.
If I do
if isinstance(t, Sequence):
    ...
elif isinstance(t, torch.Tensor):
    ...
this works now, but if tensors ever become sequences it'll break, because the first branch would then catch plain tensors too.
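One way to make this kind of dispatch independent of that question is to test for the more specific type first and treat Sequence as the fallback. A rough sketch, under the assumption that ToyTensor stands in for torch.Tensor (so it runs without torch) and with describe as a hypothetical helper; Sequence.register is used only to simulate a future in which the tensor type becomes a Sequence:

```python
from collections.abc import Sequence


class ToyTensor:
    """Stand-in for torch.Tensor (an assumption, so the sketch runs without torch)."""

    def __init__(self, values):
        self._values = list(values)

    def __len__(self):
        return len(self._values)

    def __getitem__(self, index):
        return self._values[index]


def describe(t):
    # Test the more specific type first; the generic Sequence check is the fallback.
    if isinstance(t, ToyTensor):
        return "tensor"
    if isinstance(t, Sequence) and not isinstance(t, (str, bytes)):
        return "sequence of tensors"
    raise TypeError(f"unsupported input: {type(t).__name__}")


before = describe(ToyTensor([1, 2])), describe([ToyTensor([1]), ToyTensor([2])])

# Simulate the hypothetical future in which the tensor type becomes a Sequence.
Sequence.register(ToyTensor)
after = describe(ToyTensor([1, 2])), describe([ToyTensor([1]), ToyTensor([2])])

print(before == after)  # the dispatch result does not depend on the registration
```

With the branches in this order, it no longer matters whether the tensor type is or ever becomes a Sequence.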
Perhaps a simpler illustration of the problem is this:
>>> len(torch.tensor([0]))
1
>>> len(torch.tensor(0))
TypeError: len() of a 0-d tensor
>>> [*torch.tensor([0])]
[tensor(0)]
>>> [*torch.tensor(0)]
TypeError: iteration over a 0-d tensor
That is, I suppose an n-dim tensor could be a Sequence but a 0-dim tensor can’t.
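To make the mismatch concrete, here is a small sketch of why a class-level check cannot capture this: the methods exist on the type, but whether they work depends on the instance. ToyTensor is again a hypothetical stand-in for torch.Tensor (an empty shape plays the role of a 0-d tensor), and has_length is an assumed helper name, not part of any library.

```python
from collections.abc import Sized


class ToyTensor:
    """Stand-in for torch.Tensor; shape () plays the role of a 0-d tensor."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def __len__(self):
        if not self.shape:
            raise TypeError("len() of a 0-d tensor")
        return self.shape[0]


def has_length(obj) -> bool:
    """Runtime probe: does len() actually succeed on this particular instance?"""
    try:
        len(obj)
    except TypeError:
        return False
    return True


scalar, vector = ToyTensor(()), ToyTensor((3,))

# Sized only checks that the class defines __len__, so both instances pass it.
print(isinstance(scalar, Sized), has_length(scalar))
print(isinstance(vector, Sized), has_length(vector))
```

The isinstance check is the same for both objects, while the runtime probe tells them apart, which is the gap a single Tensor type leaves for a static checker.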
This would be obvious with a type like List[int], but unfortunately it’s not so obvious with PyTorch’s Tensor type, which covers every dimensionality, including 0.
In an alternative world we’d have Tensor[int] for multi-dimensional tensors and Prim[int] for 0-dimensional primitives, and then Tensor would be able to support the Sequence interface.
At least, I’m assuming that if the PyTorch developers made Tensor inherit from collections.abc.Sequence, they would want it to type-check statically as a typing.Sequence without throwing unexpected run-time exceptions.