Say I have
a = torch.zeros((3, ))
fc = nn.Linear(3, 6)
b = fc(a)
My IDE thinks a is a torch.Tensor, but b is just Any. This forces me to do things like
b: torch.Tensor = fc(a)
a lot in my code.
Why didn't PyTorch annotate Module.__call__ to be -> torch.Tensor?
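Until the annotation changes, one way to avoid repeating the variable annotation at every call site is a small typed helper. The following is only a minimal sketch: call_as and untyped_double are hypothetical names made up for illustration, not part of PyTorch, and the example uses a plain function as a stand-in for an untyped module call so that it runs without PyTorch installed.

```python
from typing import Any, Callable, Type, TypeVar

T = TypeVar("T")


def call_as(expected: Type[T], fn: Callable[..., Any], *args: Any, **kwargs: Any) -> T:
    """Call fn and return its result typed as `expected` (hypothetical helper)."""
    out = fn(*args, **kwargs)
    # Runtime check so the static type is not just an unchecked promise.
    if not isinstance(out, expected):
        raise TypeError(f"expected {expected.__name__}, got {type(out).__name__}")
    return out


# Stand-in for a callable whose return type is Any, like fc in the post above.
def untyped_double(x: Any) -> Any:
    return x * 2


# A type checker should now see `value` as int instead of Any.
value = call_as(int, untyped_double, 21)
```

Under the same assumptions, the call from the question would read b = call_as(torch.Tensor, fc, a). The trade-off is an extra isinstance check per call and a less natural call syntax.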
I am also interested in this.
The line responsible for this behaviour seems to be here: source code
I think it is feasible to get a function's type hints via get_type_hints from the typing library.
One could use this to get the return type of the forward function defined by the user.
Maybe there is a way to add this feature to PyTorch without any problems?
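To make the get_type_hints idea concrete, here is a rough sketch. FakeTensor, FakeModule, Doubler and forward_return_type are hypothetical stand-ins so that the snippet runs without PyTorch; the loosely typed __call__ on FakeModule is an assumption about how the real Module is declared, not a copy of its source. Note that this is runtime introspection, so by itself it would not change what a static checker or IDE infers.

```python
from typing import Any, Callable, get_type_hints


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


class FakeModule:
    """Hypothetical stand-in for nn.Module with a loosely typed __call__."""

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    # Assumed shape of the declaration: checkers see module(x) as Any.
    __call__: Callable[..., Any] = _call_impl


class Doubler(FakeModule):
    # User-defined forward with a precise return annotation.
    def forward(self, x: FakeTensor) -> FakeTensor:
        return x


def forward_return_type(module: FakeModule) -> Any:
    """Read the return annotation of the module's forward at runtime."""
    hints = get_type_hints(type(module).forward)
    return hints.get("return", Any)


annotated = forward_return_type(Doubler())      # the FakeTensor class
unannotated = forward_return_type(FakeModule())  # typing.Any
```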
Thanks for the reference to #74746.
To add to that, I would hope that built-in modules such as nn.Linear come with a Tensor->Tensor annotation as well, not just user-defined classes.
I tested with mypy, and it detects nn.Linear just the way you described, i.e. Tensor->Tensor, since the source code for the built-in modules uses type hints as well. Again, the problem/feature is the way IDEs recognize the types.
Relevant source code: here
For nn.Linear, reveal_type returns Any instead of torch.Tensor. To replicate this, run the code below.
pytorch=1.13
mypy=0.991
# main.py
import torch
from torch import nn


class LinWrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(12, 2)

    def forward(
        self, x: torch.Tensor
    ) -> torch.Tensor:
        out = self.lin(x)
        reveal_type(out)  # understood by mypy only; remove before running the script
        return out
Bash command to run
mypy main.py
Output:
main.py:15: note: Revealed type is "Any"
Note that when tested with a custom module (with type hints), the revealed type is torch.Tensor.