Say I have
import torch
from torch import nn
a = torch.zeros((3,))
fc = nn.Linear(3, 6)
b = fc(a)
My IDE infers that a is a Tensor, but b is just Any. This forces me to do things like
b: torch.Tensor = fc(a)
a lot in my code.
Why didn't PyTorch annotate
Module.__call__ to be
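Until something like that exists upstream, one user-side workaround is to give your own subclass a typed __call__ that simply delegates to nn.Module.__call__. Hooks and forward() still run as usual, but IDEs and type checkers see a concrete Tensor return type. This is a sketch, not an official PyTorch API; TypedLinear is a hypothetical name, and how a given type checker treats the override may vary by version.

```python
import torch
from torch import nn


class TypedLinear(nn.Linear):
    """Hypothetical nn.Linear subclass whose __call__ is annotated Tensor -> Tensor."""

    def __call__(self, input: torch.Tensor) -> torch.Tensor:
        # Delegate to nn.Module.__call__ so forward hooks and forward() still run;
        # the annotation exists only to help static analysis / IDEs.
        return super().__call__(input)


a = torch.zeros((3,))
fc = TypedLinear(3, 6)
b = fc(a)  # an IDE should now infer b as torch.Tensor
print(type(b).__name__, tuple(b.shape))
```

The downside is that this has to be repeated for every module class you care about, which is why an annotation on Module.__call__ itself would be preferable.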
I am also interested in this.
The line responsible for this behaviour seems to be here (source code).
It seems feasible to get a function's type hints via get_type_hints from the typing library.
One could use this to get the return type of the forward function defined by the user.
Maybe there is a way to add this feature to PyTorch without any problems?
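A minimal sketch of the get_type_hints idea, using a hypothetical user module named Net. Note that this is runtime introspection: IDEs and static checkers do not execute code, so on its own this would not change what an IDE infers. It only shows that the information is available from forward's annotations.

```python
import typing

import torch
from torch import nn


class Net(nn.Module):  # hypothetical user-defined module
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(3, 6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin(x)


# Resolve the annotations declared on the user's forward() at runtime.
hints = typing.get_type_hints(Net.forward)
print(hints["return"])  # the annotated return type of forward
```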
Thanks for the reference to #74746.
To add to that, I would hope built-in modules such as nn.Linear come with a Tensor->Tensor annotation as well, not just user-defined classes.
I tested with mypy, and it detects nn.Linear just the way you described, i.e. Tensor->Tensor, since the source code for built-in modules uses type hints as well. Again, the problem (or missing feature) is the way IDEs recognize the types.
Relevant source code: here
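To check that claim about built-in modules, the annotations on nn.Linear.forward can be inspected at runtime. This is a small sketch; the parameter name in forward's signature (input in current sources) may differ across PyTorch versions, so only the return type is relied on here.

```python
import typing

import torch
from torch import nn

# nn.Linear.forward is annotated in PyTorch's own source; resolve those hints.
linear_hints = typing.get_type_hints(nn.Linear.forward)
print(linear_hints)
```

This confirms the built-in forward is annotated; what is lost is the link between forward's annotation and the type of calling the module instance.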
For nn.Linear, reveal_type returns Any instead of torch.Tensor. To replicate this, run the code below:
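import torch
from torch import nn

class Net(nn.Module):  # class name assumed
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(12, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.lin(x)
        reveal_type(out)  # understood by mypy; not meant to be executed
        return out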
Bash command to run, and its output:
mypy main.py
main.py:15: note: Revealed type is "Any"
Note that when tested with a custom module (with type hints), it reveals to be