Why not type annotate `Module` to be Tensor->Tensor?

Say I have

import torch
from torch import nn

a = torch.zeros((3,))
fc = nn.Linear(3, 6)
b = fc(a)

My IDE thinks a is a torch.Tensor but b is just Any. This forces me to do things like

b: torch.Tensor = fc(a)

a lot in my code.
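One way to avoid annotating every call site is to re-declare __call__ on a thin subclass with a concrete signature and delegate to the original. This is a minimal sketch that assumes your IDE honours a subclass's own __call__ annotation. TypedLinear is a hypothetical name, not a PyTorch class.

```python
import torch
from torch import nn


class TypedLinear(nn.Linear):
    """Hypothetical nn.Linear subclass whose __call__ has a concrete signature."""

    def __call__(self, input: torch.Tensor) -> torch.Tensor:
        # Delegate to nn.Module.__call__ so hooks and the usual call
        # machinery still run; only the static signature is narrowed.
        return super().__call__(input)


a = torch.zeros((3,))
fc = TypedLinear(3, 6)
b = fc(a)  # annotation-aware tools can now infer torch.Tensor here
```

Because the override only calls super().__call__, runtime behaviour is unchanged. Whether a given type checker accepts overriding Module.__call__ this way may vary.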

Why didn't PyTorch annotate Module.__call__ to return torch.Tensor?

I am also interested in this.
The line responsible for this behaviour seems to be here: source code

It seems feasible to get the type hints of a function via get_type_hints from the typing library.

One could use this to get the return type of the forward function defined by the user.

Maybe there is a way to add this feature to PyTorch without breaking anything?
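A minimal sketch of the get_type_hints idea, assuming a hypothetical user-defined module Doubler with an annotated forward. It resolves the return annotation at runtime, both for the user-defined module and for the built-in nn.Linear:

```python
import typing

import torch
from torch import nn


class Doubler(nn.Module):
    # Hypothetical user-defined module with an annotated forward
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


# Runtime introspection: resolve the annotations declared on forward
user_hints = typing.get_type_hints(Doubler.forward)
builtin_hints = typing.get_type_hints(nn.Linear.forward)

print(user_hints["return"])
print(builtin_hints["return"])
```

Note that this is runtime introspection. IDEs and static checkers do not execute it, so by itself it would not change what they infer for fc(a); that would need a change to the static annotation on Module.__call__.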

Just discovered this related issue: Change the type hint for nn.Module.__call__ to be friendly to overrides. · Issue #74746 · pytorch/pytorch · GitHub

Thanks for the reference to #74746.

To add to that, I would hope that built-in modules such as nn.Linear also come with a Tensor->Tensor annotation, not just user-defined classes.

I tested with mypy: it sees nn.Linear just the way you described, i.e. Tensor->Tensor, since the source code for built-in modules uses type hints as well. Again, the problem (or feature) is the way IDEs recognize the types.

Relevant source code: here

For nn.Linear, reveal_type reports Any instead of torch.Tensor. To replicate this, run the code below:

# main.py
import torch
from torch import nn

class LinWrapper(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(12, 2)

    def forward(
        self, x: torch.Tensor
    ) -> torch.Tensor:
        out = self.lin(x)
        reveal_type(out)  # understood by mypy; not a runtime builtin
        return out

Bash command to run:

mypy main.py

Output:

main.py:15: note: Revealed type is "Any"

Note that when tested with a custom module (with type hints), the revealed type is torch.Tensor.
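Until the annotation changes, a per-call workaround inside a module is typing.cast, which is a no-op at runtime and only narrows the static type seen by the checker or IDE. A sketch based on the LinWrapper example:

```python
from typing import cast

import torch
from torch import nn


class LinWrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(12, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # cast() does nothing at runtime; it only tells the type checker
        # that the Any returned by Module.__call__ is really a Tensor
        out = cast(torch.Tensor, self.lin(x))
        return out


model = LinWrapper()
y = model(torch.randn(4, 12))
```

Calls to the wrapper itself, such as model(...) above, still go through Module.__call__ and are still seen as Any.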