What I want: correct typing when I call nn modules
```python
import torch
import torch.nn as nn

linear = nn.Linear(3, 4)
tensor = torch.ones(3)
out = linear(tensor)
reveal_type(out)  # Expected: torch.Tensor; Got: Any
```
IDEs (like VSCode or PyCharm) and type checkers (like mypy) don't infer the expected type.
I've found a workaround here: adding a type annotation to `__call__`. This works fine when I create my own module, but I can't do it for `nn.Linear` or any of the other built-in modules.
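For reference, this is roughly what that workaround looks like on a custom module. It's a minimal sketch; `MyLinear` is a hypothetical name I made up, and the override simply defers to `nn.Module.__call__` at runtime, so only the annotation changes.

```python
import torch
import torch.nn as nn


class MyLinear(nn.Module):
    """Hypothetical custom module used only to illustrate the workaround."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    # Re-declare __call__ with a precise signature so type checkers see
    # Tensor -> Tensor. At runtime it just defers to nn.Module.__call__,
    # which still runs any registered hooks and then forward().
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return super().__call__(x)
```

With this, `reveal_type(MyLinear(3, 4)(torch.ones(3)))` should report `Tensor` instead of `Any`.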
I also happened to notice that back in PyTorch 1.4 the typing was correct, because `.pyi` stub files were shipped then, and they used exactly the `__call__` trick above:
```python
# pytorch 1.4
# torch/nn/modules/linear.pyi
class Linear(Module):
    # several lines skipped ...
    def forward(self, input: Tensor) -> Tensor: ...  # type: ignore
    def __call__(self, input: Tensor) -> Tensor: ...  # type: ignore
```
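Since I can't edit `nn.Linear` itself, one way to reproduce that stub signature in user code might be a thin subclass. This is only a sketch under my own assumptions: `TypedLinear` is a hypothetical name, and the `__call__` declaration sits under `typing.TYPE_CHECKING`, so only type checkers see it and nothing is overridden at runtime.

```python
from typing import TYPE_CHECKING

import torch
import torch.nn as nn


class TypedLinear(nn.Linear):
    """Hypothetical drop-in subclass of nn.Linear with a typed __call__."""

    if TYPE_CHECKING:
        # Only evaluated by type checkers; mirrors the signature the
        # PyTorch 1.4 stub gave Linear.__call__. At runtime this block is
        # skipped, so nn.Module.__call__ (hooks + forward) is used as usual.
        def __call__(self, input: torch.Tensor) -> torch.Tensor: ...
```

The downside is that it has to be repeated for every module class I use, which is exactly what I'd like to avoid.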
Apparently the `.pyi` files have since been removed and, according to this thread, the type hints were inlined into the source. But the inlined version doesn't seem to work.
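A minimal sketch of a pattern that would produce exactly this `Any`, assuming (my assumption, not confirmed in that thread) that the inlined `Module` declares `__call__` as an attribute of type `Callable[..., Any]` pointing at a generic dispatcher. `ToyModule` is a hypothetical stand-in, not PyTorch's code.

```python
from typing import Any, Callable


class ToyModule:
    """Hypothetical stand-in for nn.Module, not PyTorch's actual code."""

    def forward(self, x: int) -> int:
        return x + 1

    def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
        # Generic dispatcher: whatever forward() returns is erased to Any.
        return self.forward(*args, **kwargs)

    # Declaring __call__ as an attribute of type Callable[..., Any]
    # means every call site, e.g. ToyModule()(1), is inferred as Any,
    # even though forward() is annotated to return int.
    __call__: Callable[..., Any] = _call_impl
```

At runtime `ToyModule()(1)` returns `2`, but `reveal_type(ToyModule()(1))` would show `Any`.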
I'm new to PyTorch and that's all I know. Could someone tell me what went wrong?
- Python 3.9
- PyTorch 1.13