Hi, I want autocomplete and type checking to work when calling into a custom nn.Module. Of course, the parameters and types on forward don’t get mirrored to __call__. I’m wondering whether simply defining __call__ with the correct parameter names and types, and calling super().__call__ from it, is a legitimate approach, or whether there might sometimes be issues with it, and what those issues could be?
import torch
from torch import nn


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.h = nn.Linear(1, 1)

    # Typed wrapper so editors and type checkers see forward's signature;
    # it delegates to nn.Module.__call__, which then calls forward.
    def __call__(self, param1: torch.Tensor) -> torch.Tensor:
        return super().__call__(param1)

    def forward(self, param1: torch.Tensor) -> torch.Tensor:
        print('Foo.forward')
        param1 = self.h(param1)
        return param1


foo = Foo()
outputs = foo(param1=torch.rand(1))
print('outputs', outputs)
This code runs correctly, type-checks correctly etc. Thoughts?
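One thing worth checking with this approach is whether module hooks get skipped. Because the wrapper delegates to super().__call__, the normal nn.Module call path (forward pre-hooks, forward, forward hooks) should still run. A minimal sketch, assuming a recent PyTorch where a plain forward hook is called as hook(module, args, output); record_hook and calls are hypothetical names used only for illustration:

```python
import torch
from torch import nn


class Foo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.h = nn.Linear(1, 1)

    # Typed wrapper: exposes the real signature to static tools, then
    # delegates to nn.Module.__call__ so hooks and forward still run.
    def __call__(self, param1: torch.Tensor) -> torch.Tensor:
        return super().__call__(param1)

    def forward(self, param1: torch.Tensor) -> torch.Tensor:
        return self.h(param1)


# Hypothetical recorder: collects (module class name, output shape) per call.
calls = []


def record_hook(module: nn.Module, args: tuple, output: torch.Tensor) -> None:
    # A plain forward hook receives the module, its positional inputs,
    # and the value returned by forward.
    calls.append((type(module).__name__, output.shape))


foo = Foo()
handle = foo.register_forward_hook(record_hook)
out = foo(param1=torch.rand(1))
handle.remove()  # detach the hook once the check is done
print(calls)
```

One detail: this wrapper passes param1 on positionally, so the hook sees it in args. If a wrapper forwarded it as a keyword instead, a plain forward hook would not see it in args; in recent PyTorch versions you would need to register the hook with with_kwargs=True.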
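A different technique that aims at the same editor and type-checker benefit without overriding __call__ at runtime is to alias __call__ to forward inside an `if TYPE_CHECKING:` block. This is a minimal sketch, assuming your type checker honours typing.TYPE_CHECKING (mypy and Pyright do). Whether the assignment is accepted against nn.Module's own __call__ annotation may depend on the checker and the PyTorch version:

```python
from typing import TYPE_CHECKING

import torch
from torch import nn


class Foo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.h = nn.Linear(1, 1)

    def forward(self, param1: torch.Tensor) -> torch.Tensor:
        return self.h(param1)

    if TYPE_CHECKING:
        # Only static analysers see this: foo(...) gets forward's signature.
        # At runtime TYPE_CHECKING is False, so nn.Module.__call__ is used
        # unchanged and no extra wrapper frame is added.
        __call__ = forward


foo = Foo()
out = foo(param1=torch.rand(1))
print(out.shape)
```

The trade-off is that the runtime class remains a plain nn.Module subclass, and the alias always mirrors forward's current signature. A hand-written __call__, by contrast, has to be kept in sync with forward manually.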