RecursionError calling super().__call__ in forward

I want to subclass Sequential to create a Residual sequential class in a few lines.

import torch.nn as nn

class Residual(nn.Sequential):
    def forward(self, x):
        # nn.Module.__call__ calls self.forward, i.e. this method again
        return x + super().__call__(x)

However, this crashes with:

RecursionError: maximum recursion depth exceeded while calling a Python object

If I understand correctly, super().__call__ resolves to nn.Module.__call__, which in turn calls self.forward. That is the forward I overrode, which calls __call__ again, and so on until the recursion limit is hit.
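To make the loop concrete, here is a minimal sketch using a simplified pure-Python stand-in for nn.Module. The Module and Sequential classes below are hypothetical models written for illustration, not PyTorch's real code; they only keep the one property that matters here: __call__ dispatches to self.forward, so the subclass's forward wins.

```python
class Module:
    """Hypothetical stand-in for nn.Module: __call__ wraps forward with hooks."""

    def __init__(self):
        self.hooks = []

    def __call__(self, x):
        out = self.forward(x)  # dynamic dispatch: the subclass's forward runs
        for hook in self.hooks:
            hook(self, x, out)
        return out


class Sequential(Module):
    """Hypothetical stand-in for nn.Sequential."""

    def __init__(self, *mods):
        super().__init__()
        self.mods = mods

    def forward(self, x):
        for m in self.mods:
            x = m(x)
        return x


class Residual(Sequential):
    def forward(self, x):
        # super().__call__ is Module.__call__ (Sequential does not override it),
        # and it calls self.forward, which is this method again.
        return x + super().__call__(x)


try:
    Residual(lambda v: 2 * v)(3)
    outcome = "ok"
except RecursionError:
    outcome = "RecursionError"

print(outcome)  # RecursionError
```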

Replacing __call__ with forward solves the problem:

import torch.nn as nn

class Residual(nn.Sequential):
    def forward(self, x):
        # Call Sequential.forward directly instead of going through __call__
        return x + super().forward(x)

But then I lose the hooks, is that right?
Is my overall understanding correct?
Could calling Sequential's forward directly cause problems for wrappers such as DataParallel?

Thank you very much in advance!

The simplest solution I can come up with is the following:

import torch.nn as nn

class Residual(nn.Module):
    def __init__(self, *mods):
        super().__init__()
        self.seq = nn.Sequential(*mods)

    def forward(self, x):
        # self.seq(x) goes through nn.Module.__call__, so its hooks run
        return x + self.seq(x)

This should retain all the hooks with 3 extra lines of code!
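One trade-off worth knowing when choosing between the two versions: wrapping the Sequential in an attribute changes the parameter names, which matters when loading existing checkpoints. A small sketch follows; the class names ResidualSub and ResidualWrap are hypothetical, chosen only to tell the two variants apart.

```python
import torch.nn as nn


class ResidualSub(nn.Sequential):
    """Subclass variant: children are registered directly as '0', '1', ..."""

    def forward(self, x):
        return x + super().forward(x)


class ResidualWrap(nn.Module):
    """Wrapper variant: children live under the 'seq' attribute."""

    def __init__(self, *mods):
        super().__init__()
        self.seq = nn.Sequential(*mods)

    def forward(self, x):
        return x + self.seq(x)


sub = ResidualSub(nn.Linear(4, 4))
wrap = ResidualWrap(nn.Linear(4, 4))

sub_keys = list(sub.state_dict())
wrap_keys = list(wrap.state_dict())
print(sub_keys)   # ['0.weight', '0.bias']
print(wrap_keys)  # ['seq.0.weight', 'seq.0.bias']
```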


Hi,

The hooks of this module are handled by nn.Module.__call__, which wraps your forward: forward pre-hooks run before it is entered and forward hooks run after it returns.
So you should use super().forward(x); no hooks are lost.
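A quick way to check this is to register forward hooks and record the order in which they fire. In this sketch the helper record and the list calls are hypothetical names for illustration; register_forward_hook is the standard nn.Module method. The child's hook also fires because nn.Sequential.forward calls each child module through its own __call__.

```python
import torch
import torch.nn as nn


class Residual(nn.Sequential):
    def forward(self, x):
        # Call Sequential.forward directly; nn.Module.__call__ on the
        # Residual instance already runs this module's hooks around us.
        return x + super().forward(x)


calls = []


def record(name):
    # Hypothetical helper: builds a forward hook that logs its name.
    def hook(module, inputs, output):
        calls.append(name)
    return hook


block = Residual(nn.Linear(4, 4), nn.ReLU())
block.register_forward_hook(record("residual"))
block[0].register_forward_hook(record("linear"))

y = block(torch.randn(2, 4))
print(calls)  # ['linear', 'residual']
```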