__call__ signature of nn.Module

https://github.com/lab-ml/labml/blob/master/labml/helpers/pytorch/module.py

This is a class I wrote that subclasses nn.Module and lets modules implement __call__ instead of forward. I did this to get better type checking and type inference in Python.

This module uses __init_subclass__ to move the __call__ method of its subclasses to forward, so that it doesn’t interfere with nn.Module hooks and related machinery. I’ve been using this for a while and haven’t encountered any problems.
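The renaming trick can be sketched without PyTorch installed. In the sketch below, BaseModule is a hypothetical stand-in for nn.Module: its __call__ runs forward and then any registered forward hooks, loosely mirroring what nn.Module does. The __init_subclass__ hook copies a __call__ defined on a subclass over to forward and then deletes it, so attribute lookup for __call__ falls back to the base class. This is an assumption about how the linked helper works, not a copy of it.

```python
from typing import Any, Callable, List

Hook = Callable[["BaseModule", tuple, Any], Any]


class BaseModule:
    """Hypothetical stand-in for nn.Module: __call__ wraps forward with hooks."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError

    def __call__(self, *args: Any) -> Any:
        result = self.forward(*args)
        # As in nn.Module, a hook may replace the output by returning non-None.
        for hook in self._forward_hooks:
            hooked = hook(self, args, result)
            if hooked is not None:
                result = hooked
        return result


class Module(BaseModule):
    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Only act on a __call__ defined directly on this subclass.
        impl = cls.__dict__.get("__call__")
        if impl is None:
            return
        # Move the user's __call__ to forward ...
        setattr(cls, "forward", impl)
        # ... and remove it, so lookups of __call__ resolve to BaseModule.__call__.
        delattr(cls, "__call__")


class Scale(Module):
    def __init__(self, factor: float) -> None:
        super().__init__()
        self.factor = factor

    # The type checker sees this precise signature instead of a generic one.
    def __call__(self, x: float) -> float:
        return x * self.factor
```

With this setup, Scale(3.0)(2.0) returns 6.0 and any registered forward hook still fires, because the call goes through BaseModule.__call__ rather than the subclass’s own method.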

Are there implications that I haven’t seen?

Hi,

So this does not call the original __call__ anymore?

It does. Basically, we use __init_subclass__ to rename your implementation of __call__ to forward.

But the original implementation of __call__ then calls forward. Doesn’t that lead to infinite recursion?

Sorry, I didn’t see your comment.

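No, it doesn’t call itself. Only the subclass’s implementation of __call__ is renamed to forward; nn.Module.__call__ itself is left untouched, so it still calls forward, which is now your original implementation.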
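The recursion question comes down to attribute lookup order. Once the subclass’s own __call__ is copied to forward and deleted, type(m).__call__ resolves through the MRO to the base class’s __call__, which calls self.forward, i.e. the user’s original function; nothing points back at the base __call__. The self-contained sketch below uses a hypothetical Base in place of nn.Module and records the call chain to show that it terminates.

```python
from typing import Any, List, Optional

trace: List[str] = []  # records each step of a call


class Base:
    """Hypothetical stand-in for nn.Module: __call__ dispatches to forward."""

    def __call__(self, *args: Any) -> Any:
        trace.append("Base.__call__")
        return self.forward(*args)

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError


class Renaming(Base):
    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        impl = cls.__dict__.get("__call__")
        if impl is not None:
            setattr(cls, "forward", impl)  # user's __call__ becomes forward
            delattr(cls, "__call__")  # so the inherited Base.__call__ is used


class Double(Renaming):
    def __call__(self, x: int) -> int:
        trace.append("Double body")
        return 2 * x


def resolve_call(cls: type) -> Optional[str]:
    """Name the first class in the MRO whose own __dict__ defines __call__."""
    for klass in cls.__mro__:
        if "__call__" in klass.__dict__:
            return klass.__name__
    return None
```

Calling Double()(21) returns 42 and appends exactly two entries to trace, "Base.__call__" followed by "Double body"; resolve_call(Double) names Base. Genuine recursion would instead end in a RecursionError.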
