I am trying to create a Python dataclass that supports both CPU and GPU DoubleTensors:
from dataclasses import dataclass
from typing import Union
import torch
@dataclass
class SomeDC:
    element: Union[torch.DoubleTensor, torch.cuda.DoubleTensor]
However, mypy reports this error:
error: Name 'torch.cuda.DoubleTensor' is not defined
I tried adding import torch.cuda, but the same error occurs. Does anyone know of a workaround, or should I just ignore this mypy error?
Note that the code runs fine as written; the only problem is the mypy error.
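One commonly used workaround is sketched below. It assumes the real requirement is just "a float64 tensor on either device". In current PyTorch, CPU and CUDA tensors are both instances of the single class torch.Tensor, so the field can be annotated with that class and the "double" requirement can be checked at runtime in __post_init__. The float64 check and the TypeError are illustrative choices, not something the question prescribes.

```python
from dataclasses import dataclass

import torch


@dataclass
class SomeDC:
    # torch.Tensor is the one class shared by CPU and CUDA tensors,
    # so no device-specific legacy type name is needed in the annotation.
    element: torch.Tensor

    def __post_init__(self) -> None:
        # dtype is a runtime attribute rather than part of the static type,
        # so the "must be a double tensor" constraint is enforced here
        # (an illustrative check, not required by the original code).
        if self.element.dtype != torch.float64:
            raise TypeError(
                f"expected a float64 (double) tensor, got {self.element.dtype}"
            )


cpu_dc = SomeDC(torch.zeros(3, dtype=torch.float64))
if torch.cuda.is_available():
    # The same annotation accepts a tensor that lives on the GPU.
    gpu_dc = SomeDC(torch.zeros(3, dtype=torch.float64, device="cuda"))
```

If the Union annotation has to stay as written, appending a # type: ignore comment to that line silences mypy for that line only, at the cost of hiding any other error reported there.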