Adding typing to __call__ of nn.Module?

Hi, I want autocomplete and type checking to work when calling into a custom nn.Module. Of course, any parameters and types on forward don’t get mirrored to __call__. I’m wondering whether simply defining __call__ with the correct parameter names and types, and calling super().__call__ from it, is a legitimate approach, or whether it might cause issues in some cases, and what those issues could be?

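import torch
from torch import nn


class Foo(torch.nn.Module):
    def __init__(self):
        # nn.Module.__init__ must run before submodules such as self.h can be assigned
        super().__init__()
        self.h = nn.Linear(1, 1)

    # Mirrors forward's signature so IDEs and type checkers see the real parameters
    def __call__(self, param1: torch.Tensor) -> torch.Tensor:
        # Delegate to nn.Module.__call__ so hooks and the rest of the Module machinery still run
        return super().__call__(param1)

    def forward(self, param1: torch.Tensor) -> torch.Tensor:
        param1 = self.h(param1)
        return param1


foo = Foo()
outputs = foo(param1=torch.rand(1))
print('outputs', outputs)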

This code runs correctly, type-checks correctly etc. Thoughts?
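A small sketch of what the override changes for runtime introspection: `inspect.signature` on a module instance reports the signature of `type(obj).__call__`, so with the override it shows `param1` and its annotation instead of a generic `*args, **kwargs` form. The comparison class `Plain` is a hypothetical name added only for illustration. This checks runtime introspection only; static checkers such as mypy or pyright read the annotations from the source without running it.

```python
import inspect

import torch
from torch import nn


class Foo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.h = nn.Linear(1, 1)

    def __call__(self, param1: torch.Tensor) -> torch.Tensor:
        return super().__call__(param1)

    def forward(self, param1: torch.Tensor) -> torch.Tensor:
        return self.h(param1)


class Plain(nn.Module):
    """Hypothetical comparison module: same forward, no __call__ override."""

    def __init__(self) -> None:
        super().__init__()
        self.h = nn.Linear(1, 1)

    def forward(self, param1: torch.Tensor) -> torch.Tensor:
        return self.h(param1)


# For a callable instance, inspect.signature uses type(obj).__call__ and drops `self`
typed_sig = inspect.signature(Foo())
plain_sig = inspect.signature(Plain())
print("with override:   ", typed_sig)
print("without override:", plain_sig)
```

Without the override, the reported signature comes from nn.Module's own generic `__call__`, which accepts arbitrary positional and keyword arguments.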

I can’t think of a breaking use case at the moment, as hooks, nn.DataParallel, etc. should all work with your approach. This is of course not a thorough verification, so let me think about some more use cases.
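A quick sketch for checking the hooks part: register a forward pre-hook and a forward hook on the overriding module and confirm both still fire when it is called with a keyword argument. The `calls` list is a hypothetical name used only to record hook invocations.

```python
import torch
from torch import nn


class Foo(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.h = nn.Linear(1, 1)

    def __call__(self, param1: torch.Tensor) -> torch.Tensor:
        # param1 is forwarded positionally into nn.Module.__call__
        return super().__call__(param1)

    def forward(self, param1: torch.Tensor) -> torch.Tensor:
        return self.h(param1)


calls = []  # records (hook kind, number of positional args the hook saw)
foo = Foo()
# Both hooks return None, so they leave the inputs and output unchanged
foo.register_forward_pre_hook(lambda module, args: calls.append(("pre", len(args))))
foo.register_forward_hook(lambda module, args, output: calls.append(("post", len(args))))

out = foo(param1=torch.rand(1))
print(calls, out.shape)
```

One side effect worth knowing: because the override passes `param1` positionally to `super().__call__`, hooks registered without `with_kwargs=True` see the tensor in `args` even when the caller used a keyword. Calling a module without the override as `foo(param1=...)` would leave `args` empty for such hooks.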
