I’m trying to annotate subclasses of `nn.Module` inline, but so far I have been unable to get it to work. For my project I created an abstract subclass of `nn.Module`:
```python
from typing import Any

from torch import nn


class Foo(nn.Module):
    def forward(self, *input: Any, **kwargs: Any) -> Any:
        pass
```
Running `mypy` on this succeeds. In a second step, I added a more concrete subclass:
```python
import torch


class Bar(Foo):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass
```
Now `mypy` fails with

```
error: Signature of "forward" incompatible with supertype "Foo"
error: Signature of "forward" incompatible with supertype "Module"
```

and both errors point to `forward()` of `Bar`.
I’m aware that this is a valid error, since `Bar` violates the Liskov Substitution Principle. Looking at the stub of `nn.Module`, I think the marked paragraph is related, but so far I have been unable to comprehend it.
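To make the Liskov violation concrete, here is a minimal, torch-free sketch. `Base`, `Narrow`, and `call_with_two` are hypothetical stand-ins for `Foo`, `Bar`, and any code written against `Foo`. A call that is valid for the base signature fails at runtime once a subclass narrows that signature, which is the situation mypy is guarding against:

```python
from typing import Any


class Base:
    # Accepts any positional and keyword arguments, like Foo.forward.
    def forward(self, *input: Any, **kwargs: Any) -> Any:
        return len(input) + len(kwargs)


class Narrow(Base):
    # Accepts exactly one argument, like Bar.forward.
    # mypy reports this override as incompatible with Base.
    def forward(self, input: Any) -> Any:
        return input


def call_with_two(module: Base) -> Any:
    # Type-checks fine: Base.forward accepts any number of arguments.
    return module.forward(1, 2)


print(call_with_two(Base()))  # prints 2

try:
    call_with_two(Narrow())  # a Narrow is a Base, yet this call breaks
except TypeError as exc:
    print("TypeError:", exc)
```

Substituting a `Narrow` wherever a `Base` is expected turns a well-typed call into a `TypeError`, so mypy rejects the narrower override.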
But since I don’t think every subclass of `nn.Module` is intended to have that exact signature, I’m puzzled about how to get around this. Can someone help me out here?