Hi! I am having trouble type-checking my `nn.Module` subclasses. Here is an example:
```python
# example.py
import torch
import torch.nn as nn


class Model(nn.Module):
    def forward(self) -> int:
        pass


model = Model()
reveal_type(model.forward)
reveal_type(model.__call__)
```
Running `mypy example.py` shows the following:

```
example.py:12: note: Revealed type is 'def () -> builtins.int'
example.py:13: note: Revealed type is 'def (*Any, **Any) -> Any'
```
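The first line is what I expect. The second is not: `__call__` accepts any arguments and returns `Any` instead of mirroring the signature of `forward`.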
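Now if I use the model as a callable, `out = model()`, then `reveal_type(out)` gets me `Any`, which is problematic: the `int` return type declared on `forward` is lost, so mypy silently accepts any misuse of `out`.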
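For anyone who wants to reproduce the situation without installing torch, here is a minimal torch-free sketch. `LooseModule` is a hypothetical stand-in for `nn.Module`, not torch's actual code; it assumes, based on the revealed type above, that `__call__` is annotated to accept anything and return `Any`, while `forward` is left as a loosely typed `Callable` so subclasses can pick their own signature.

```python
from typing import TYPE_CHECKING, Any, Callable


class LooseModule:
    # Hypothetical stand-in for nn.Module (an assumption, not torch's code):
    # subclasses may give forward any signature, and __call__ is typed loosely.
    forward: Callable[..., Any]

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # At runtime the call is simply forwarded to the subclass's forward(),
        # but statically the result is whatever __call__ declares: Any.
        return self.forward(*args, **kwargs)


class Model(LooseModule):
    def forward(self) -> int:
        return 42


model = Model()
out = model()  # runs forward() at runtime, so out == 42

if TYPE_CHECKING:
    # Only seen by the type checker; this block never runs.
    reveal_type(model.forward)  # mypy reports the precise () -> int signature
    reveal_type(out)  # mypy reports Any, since __call__ is declared to return Any
```

Running the file with plain Python prints nothing; the `reveal_type` notes only appear when it is passed to mypy, which shows the same gap between `forward` and `__call__` as the torch example.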
Could somebody show me how to type-check this properly, so that calling the model gives the return type of `forward`?
Thanks in advance!