How to correctly annotate subclasses of nn.Module?

I’m trying to annotate subclasses of nn.Module inline, but so far I have been unable to get it to work. For my project I create an abstract subclass of nn.Module:

from typing import Any
from torch import nn


class Foo(nn.Module):
    def forward(self, *input: Any, **kwargs: Any) -> Any:
        pass

Running mypy on this succeeds. In a second step I add a more concrete subclass:

import torch


class Bar(Foo):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass

Now mypy errors with

error: Signature of "forward" incompatible with supertype "Foo"
error: Signature of "forward" incompatible with supertype "Module" 

with both errors pointing at forward() of Bar. I’m aware that this is a valid error, since Bar violates the Liskov Substitution Principle. Looking at the stub of nn.Module, I think the marked paragraph is related, but so far I have been unable to make sense of it.
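To make the Liskov point concrete, here is a minimal sketch (plain classes standing in for the torch ones, so it runs without torch installed). Code that is typed against Foo may legally call forward with any arguments, and Bar cannot honor that. The helper name accepts_foo_call is hypothetical.

```python
from typing import Any


class Foo:
    # Base signature: accepts any positional and keyword arguments.
    def forward(self, *input: Any, **kwargs: Any) -> Any:
        return None


class Bar(Foo):
    # Narrower signature: exactly one positional argument.
    def forward(self, input: Any) -> Any:  # flagged by mypy as incompatible with "Foo"
        return input


def accepts_foo_call(model: Foo) -> bool:
    """Return True if model handles a call that Foo's signature permits."""
    try:
        # Type-checks fine against Foo.forward(*input, **kwargs).
        model.forward(1, 2, scale=3)
    except TypeError:
        return False
    return True


print(accepts_foo_call(Foo()))  # True
print(accepts_foo_call(Bar()))  # False: Bar is not a drop-in substitute for Foo
```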

But since I don’t think every subclass of nn.Module is meant to have that exact signature, I’m puzzled how to get around this. Can someone help me out here?
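One pattern that avoids the error, offered as a sketch: as far as I know, recent PyTorch releases declare forward on nn.Module as a class attribute typed Callable[..., Any] (assigned a placeholder that raises NotImplementedError) rather than as a method, which is why plain nn.Module subclasses can override it with any signature. The same trick can be applied to an intermediate base like Foo. Below, a plain class stands in for nn.Module so the example runs without torch; its __call__ only mimics nn.Module dispatching to forward, and _forward_unimplemented is a local placeholder, not an imported torch function.

```python
from typing import Any, Callable


def _forward_unimplemented(self: Any, *input: Any) -> Any:
    # Placeholder that concrete subclasses are expected to replace.
    raise NotImplementedError(f'{type(self).__name__} is missing "forward"')


class Foo:
    # An attribute of type Callable[..., Any] instead of a method with a fixed
    # signature, so mypy has no concrete signature to compare overrides against.
    forward: Callable[..., Any] = _forward_unimplemented

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Stand-in for nn.Module.__call__, which dispatches to forward.
        return self.forward(*args, **kwargs)


class Bar(Foo):
    # Narrower signature; compatible with the base's Callable[..., Any].
    def forward(self, input: float) -> float:
        return input * 2


print(Bar()(3.0))  # 6.0
```

With torch, Foo would subclass nn.Module and could simply omit its own forward (or keep only the Callable[..., Any] annotation); the __call__ stand-in would not be needed.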

I’m seeing a similar problem. I have code that looks like:

import torch.nn as nn

class Foo(nn.Module):
    ...

but mypy errors out with:

error: Name "nn.Module" is not defined
error: Class cannot subclass "Module" (has type "Any")

No idea what this means or how to solve it.
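A hedged explanation of the second error: mypy reports "Class cannot subclass ... (has type "Any")" when the disallow_subclassing_any option is on (it is part of --strict) and the base class resolves to Any. That usually means mypy could not load torch's type information (for example, the import is treated as missing or untyped), which likely also explains the "not defined" error. Making torch's annotations visible to mypy (often by upgrading torch) is the cleaner fix; otherwise the error can be silenced per class. The sketch below reproduces the situation without torch by using a hypothetical Any-typed stand-in named Base.

```python
from typing import Any

# Hypothetical stand-in for nn.Module when mypy cannot resolve torch's types:
# its static type is Any, just like an unresolved import.
Base: Any = object


# Under disallow_subclassing_any, mypy is expected to flag this line with
# 'Class cannot subclass "Base" (has type "Any")'; the ignore comment
# suppresses errors on this line only.
class Foo(Base):  # type: ignore
    def forward(self, x: int) -> int:
        return x + 1


print(Foo().forward(1))  # 2
```

The ignore comment can be narrowed to the error code mypy prints next to the message, so unrelated errors on that line still surface.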
