Hello,
I would like to use torch.Tensor as a type hint, e.g.:
import torch
data: torch.Tensor = torch.tensor(2)
but if I do that, mypy reports:
foo.py:1: error: Cannot find implementation or library stub for module named "torch"
I saw there's jaxtyping (patrick-kidger/jaxtyping on GitHub, docs at https://docs.kidger.site/jaxtyping/), which provides type annotations and runtime checking for the shape and dtype of JAX/NumPy/PyTorch arrays. It can be used for much more in-depth typing, but for now a very basic type like the one above would be sufficient for me.
So, before I dig into jaxtyping: is it possible to just use torch.Tensor as a type?
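For reference, here is a minimal sketch of the kind of basic annotation meant here, applied to a function signature as well as to a variable. The function `scale` is a hypothetical illustration, not part of any library. Annotations like these are not enforced when the code runs; only a static checker such as mypy reads them, and the checker still has to be able to find the installed torch package.

```python
import torch


def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    """Hypothetical helper: return a new tensor with every element of x multiplied by factor."""
    return x * factor


# Annotating a plain variable, as in the snippet above.
data: torch.Tensor = torch.tensor(2)

# The annotations are not checked at runtime; only static checkers read them.
result: torch.Tensor = scale(data, 3.0)
print(result)
```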