What I want: correct return types when I call `nn` modules, e.g.:
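```python
import torch
import torch.nn as nn

linear = nn.Linear(3, 4)
tensor = torch.ones(3)
out = linear(tensor)
# reveal_type is interpreted by the type checker; it is not defined at runtime on Python 3.9
reveal_type(out)  # Expected: torch.Tensor; Got: Any
```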
IDEs (e.g. VSCode, PyCharm) and type checkers (e.g. `mypy`) don't infer the expected type.
I found a workaround here: add a type annotation to `__call__`. This works fine for modules I write myself, but I can't do the same for `nn.Linear` or any of the other built-in modules in `nn`.
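For a built-in module, one possible way to apply the same `__call__` trick is to subclass it and redeclare `__call__` with a concrete signature that just delegates to `nn.Module.__call__`, so forward hooks still run and only the static type changes. This is a sketch assuming subclassing is acceptable in your code; `TypedLinear` is a hypothetical name, not part of PyTorch. For an unmodified module at a single call site, `typing.cast` is a lighter alternative.

```python
import typing

import torch
import torch.nn as nn


class TypedLinear(nn.Linear):
    """Hypothetical nn.Linear subclass whose __call__ has a concrete signature."""

    def __call__(self, input: torch.Tensor) -> torch.Tensor:
        # Delegate to nn.Module.__call__ so hooks and the usual dispatch to
        # forward() still happen; only the annotation differs.
        return super().__call__(input)


linear = TypedLinear(3, 4)
out = linear(torch.ones(3))  # a type checker should now see torch.Tensor

# Alternative for a plain nn.Linear: cast the result at the call site.
plain = nn.Linear(3, 4)
out2 = typing.cast(torch.Tensor, plain(torch.ones(3)))
```

The subclass has to be used everywhere you construct the layer, whereas `cast` needs no new class but must be repeated at every call.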
I also stumbled on the fact that back in PyTorch 1.4 the typing was correct, because `.pyi` stub files were shipped then, and they used the same `__call__` trick as above:
```python
# pytorch 1.4
# torch/nn/modules/linear.pyi
class Linear(Module):
    # several lines skipped ...
    def forward(self, input: Tensor) -> Tensor: ...  # type: ignore
    def __call__(self, input: Tensor) -> Tensor: ...  # type: ignore
```
Apparently the `.pyi` files have since been removed and, according to this thread, their annotations were inlined into the `.py` sources. But the inlined version doesn't seem to work.
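A likely explanation, assuming PyTorch 1.13's `torch/nn/modules/module.py` declares `__call__ : Callable[..., Any] = _call_impl` on `nn.Module` (worth checking in your installed copy), is that the inlined annotation for `__call__` is `Callable[..., Any]`, so calling any subclass is typed `Any` regardless of how `forward` is annotated. The torch-free sketch below mimics that pattern with hypothetical `MiniModule`, `Double` and `TypedDouble` classes: both calls behave the same at runtime, but only the class that redeclares `__call__` gives a type checker a concrete return type.

```python
from typing import Any, Callable, List


def _forward_unimplemented(self: Any, *args: Any) -> Any:
    raise NotImplementedError


class MiniModule:
    """Hypothetical stand-in for nn.Module's call/forward dispatch."""

    # Annotated as a loose Callable so subclasses may narrow forward's signature.
    forward: Callable[..., Any] = _forward_unimplemented

    def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
        # The real nn.Module also runs hooks here; this sketch only forwards.
        return self.forward(*args, **kwargs)

    # Mimics nn.Module: every call on a subclass is typed as Any.
    __call__: Callable[..., Any] = _call_impl


class Double(MiniModule):
    def forward(self, x: List[float]) -> List[float]:
        return [2 * v for v in x]


class TypedDouble(Double):
    # Redeclaring __call__ (the 1.4 .pyi trick) restores a concrete return type.
    def __call__(self, x: List[float]) -> List[float]:
        return super().__call__(x)


untyped = Double()([1.0, 2.0])     # a type checker should report Any here
typed = TypedDouble()([1.0, 2.0])  # ...and List[float] here
```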
I'm new to PyTorch and that's all I know. Could someone tell me what went wrong?
Environment
- Python 3.9
- PyTorch 1.13