How to get correct typing from nn.Module.__call__?

What I want: a correctly inferred return type when I call nn modules, e.g.:

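```python
import torch
import torch.nn as nn

linear = nn.Linear(3, 4)
tensor = torch.ones(3)
out = linear(tensor)
# reveal_type is a type-checker directive (mypy/pyright); it is not defined at runtime here
reveal_type(out)  # Expected: torch.Tensor; Got: Any
```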

IDEs (such as VS Code or PyCharm) and type checkers (such as mypy) don't infer the expected type; they report Any instead of torch.Tensor.
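As a stopgap at the call site (it does not fix the inference itself), the result can be annotated explicitly or wrapped in typing.cast. A minimal sketch; neither option changes runtime behavior:

```python
from typing import cast

import torch
import torch.nn as nn

linear = nn.Linear(3, 4)
tensor = torch.ones(3)

# Option 1: annotate the target; Any is assignable to torch.Tensor, so checkers accept it
out: torch.Tensor = linear(tensor)

# Option 2: cast explicitly; at runtime cast() simply returns its second argument
out2 = cast(torch.Tensor, linear(tensor))
```

The downside is that every call site needs the annotation, which is exactly what a correct `__call__` signature would avoid.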

I've found a workaround here: annotating __call__ explicitly. This works fine for my own modules, but I can't do it for nn.Linear or any of the other built-in modules in nn.
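For reference, that workaround on a custom module looks roughly like the sketch below. It assumes the goal is only to give type checkers a precise signature: `MyLinear` is a hypothetical module, and its `__call__` delegates to `nn.Module.__call__` so forward hooks still run.

```python
import torch
import torch.nn as nn


class MyLinear(nn.Module):
    """Hypothetical module wrapping nn.Linear, used only to illustrate the __call__ annotation."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    # Re-declare __call__ with a precise signature for type checkers.
    # At runtime it just delegates to nn.Module.__call__, so hooks are preserved.
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return super().__call__(x)


model = MyLinear(3, 4)
calls = []
# The hook records the output shape, confirming nn.Module's call machinery still runs
model.register_forward_hook(lambda mod, inp, out: calls.append(out.shape))
y = model(torch.ones(3))  # checkers should now infer torch.Tensor
```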

I also happened to find that back in PyTorch 1.4 the typing was correct, because .pyi stub files were shipped then, and they used the same __call__ trick:

```python
# pytorch 1.4
# torch/nn/modules/linear.pyi
class Linear(Module):
    # several lines skipped ...
    def forward(self, input: Tensor) -> Tensor: ...  # type: ignore
    def __call__(self, input: Tensor) -> Tensor: ...  # type: ignore
```
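The same stub-style declaration can be reproduced in user code for a built-in layer by subclassing it. This is a sketch under assumptions: `TypedLinear` is a hypothetical name, it only helps where you construct `TypedLinear` instead of `nn.Linear`, and some checkers may still want a `# type: ignore` on the override, as the 1.4 stub had.

```python
from typing import TYPE_CHECKING

import torch
import torch.nn as nn


class TypedLinear(nn.Linear):
    """Hypothetical nn.Linear subclass that only adds a typed __call__ for static analysis."""

    if TYPE_CHECKING:
        # Seen only by type checkers; at runtime TYPE_CHECKING is False,
        # so the class keeps nn.Module.__call__ and behaves exactly like nn.Linear.
        def __call__(self, input: torch.Tensor) -> torch.Tensor: ...


layer = TypedLinear(3, 4)
out = layer(torch.ones(3))  # checkers should infer torch.Tensor
```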

Apparently the .pyi files were removed later and, according to this thread, their annotations were inlined into the .py sources. But the inlined annotations don't seem to work.
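One likely explanation, stated as an assumption: in PyTorch 1.x, `nn.Module` declares `forward` and `__call__` as class attributes typed `Callable[..., Any]` (`__call__ : Callable[..., Any] = _call_impl`), so whatever `forward` is annotated with, checkers see calling any module as returning `Any`. The torch-free sketch below mimics that pattern (`Base` and `Double` are hypothetical stand-ins) to show how the annotation erases the return type:

```python
from typing import Any, Callable


def _forward_unimplemented(self: Any, *args: Any) -> Any:
    raise NotImplementedError


class Base:
    """Hypothetical stand-in for nn.Module."""

    # Typed as a plain callable attribute, so subclasses may override it with any signature
    forward: Callable[..., Any] = _forward_unimplemented

    def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
        # The real nn.Module also runs hooks here; this sketch just forwards
        return self.forward(*args, **kwargs)

    # Same shape as nn.Module: __call__ is a class attribute typed Callable[..., Any]
    __call__: Callable[..., Any] = _call_impl


class Double(Base):
    def forward(self, x: int) -> int:  # precise annotation on forward...
        return 2 * x


y = Double()(3)  # ...but checkers infer Any here, because __call__ returns Any
```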

I'm new to PyTorch and that's all I know. Could someone tell me what went wrong?

Environment

  • Python 3.9
  • PyTorch 1.13

A relevant issue: issue link