How to correctly annotate subclasses of nn.Module?

I’m trying to annotate subclasses of nn.Module inline, but so far I have been unable to get it to work. For my project I create an abstract subclass of nn.Module:

from typing import Any
from torch import nn

class Foo(nn.Module):
    def forward(self, *input: Any, **kwargs: Any) -> Any:
        raise NotImplementedError  # abstract: concrete subclasses override this

Running mypy on this succeeds. In a second step I add a more concrete class

import torch

class Bar(Foo):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input  # placeholder body; the real computation is omitted

Now mypy errors with

error: Signature of "forward" incompatible with supertype "Foo"
error: Signature of "forward" incompatible with supertype "Module" 

Both errors point to forward() of Bar. I’m aware that this is a valid error, since Bar violates the Liskov Substitution Principle: Foo.forward accepts any arguments, while Bar.forward accepts exactly one tensor. Looking at the stub of nn.Module, I think the marked paragraph is related, but so far I have been unable to make sense of it.
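To make the violation concrete, here is a minimal sketch that runs without PyTorch. The classes are plain-Python stand-ins for the ones above (an assumption for illustration, with int replacing torch.Tensor), and call_with_two is a hypothetical helper:

```python
from typing import Any


class Foo:  # stand-in for the abstract nn.Module subclass
    def forward(self, *input: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class Bar(Foo):  # stand-in for the concrete subclass
    def forward(self, input: int) -> int:  # narrower than Foo.forward
        return input


def call_with_two(module: Foo) -> Any:
    # This call is valid against Foo's declared signature,
    # so a type checker has no reason to reject it.
    return module.forward(1, 2)


try:
    call_with_two(Bar())
    failed = False
except TypeError:
    # Bar.forward takes a single argument, so the call breaks at runtime.
    failed = True
```

A caller that only knows about Foo may pass two arguments, which Bar cannot handle. That mismatch is what the override check guards against.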

But since I think it is not intended for every subclass of nn.Module to have that exact signature, I’m puzzled how to get around this. Can someone help me out here?
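One possible way around this, sketched under assumptions: declare forward in the base class as a callable attribute instead of a method with a fixed signature. As far as I know, newer PyTorch releases annotate nn.Module.forward in a similar way, but treat that as something to verify against your installed version. The classes below are plain-Python stand-ins so the snippet runs without PyTorch, and _forward_unimplemented is a hypothetical helper name:

```python
from typing import Any, Callable


def _forward_unimplemented(self: Any, *input: Any) -> None:
    # Default used until a subclass provides its own forward.
    raise NotImplementedError


class Foo:  # stand-in for the abstract nn.Module subclass
    # Typed as "callable with any arguments", so a subclass is free
    # to pick its own parameter list for forward.
    forward: Callable[..., Any] = _forward_unimplemented


class Bar(Foo):
    def forward(self, input: float) -> float:  # float stands in for torch.Tensor
        return input * 2.0


result = Bar().forward(1.5)
```

The trade-off is that calls made through a variable typed as Foo are no longer checked against any particular signature. With real PyTorch, the simplest variant of this idea is probably to not redeclare forward in the abstract class at all.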

I’m seeing a similar problem. I have code that looks like:

import torch.nn as nn

class Foo(nn.Module):
    ...  # class body omitted

but mypy errors out with:

error: Name "nn.Module" is not defined
error: Class cannot subclass "Module" (has type "Any")

No idea what this means or how to solve it.
