How to use dataclass with PyTorch

Although @ptrblck’s answer works, it partly defeats the purpose of using a dataclass, which is to avoid writing the __init__ method yourself.

So here are some requirements to make this work:

  1. The PyTorch module class (which is the dataclass itself) needs a __hash__ method, because named_modules in nn.Module stores the modules it has seen in a set. A plain @dataclass generates __eq__, which sets __hash__ to None, so hashing has to be restored explicitly.
  2. We need to call super().__init__() at some point.
  3. The dataclass must not be frozen, because the __init__ method of nn.Module needs to set attributes on self. So you cannot use @dataclass(frozen=True) as a way to get a __hash__ method for your dataclass.
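To illustrate requirement 3 with plain Python (no torch required; the class name here is just an illustration): a frozen dataclass forbids every attribute assignment after its generated __init__ runs, which is exactly the kind of assignment nn.Module.__init__ has to perform.

```python
from dataclasses import dataclass, FrozenInstanceError

@dataclass(frozen=True)
class FrozenConfig:
    x: int = 0

cfg = FrozenConfig()
try:
    # nn.Module.__init__ does the equivalent of this (e.g. setting
    # self._parameters), and would fail in exactly the same way
    cfg.y = 1
    raised = False
except FrozenInstanceError:
    raised = True
```

So frozen=True gives you a __hash__ for free, but makes the class unusable as a module.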

The only solution I found that is slightly better than @ptrblck’s answer is this:

from dataclasses import dataclass

import torch.nn as nn

@dataclass(unsafe_hash=True)
class Net(nn.Module):
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        super().__init__()  # must run before any submodules are assigned
        self.layer = nn.Linear(self.input_feats, self.output_feats)

Notice the use of __post_init__, which the dataclass-generated __init__ calls after setting the fields, and the ugly hack of setting unsafe_hash=True.
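As a quick sanity check (assuming torch is installed): the generated __init__ still takes the fields as keyword arguments, named_modules works because the instance is hashable, and the linear layer behaves as usual.

```python
import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass(unsafe_hash=True)
class Net(nn.Module):
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        super().__init__()
        self.layer = nn.Linear(self.input_feats, self.output_feats)

net = Net(input_feats=4, output_feats=2)

# named_modules iterates without error because Net instances are hashable
names = [name for name, _ in net.named_modules()]

# the Linear submodule was registered normally under "layer"
out = net.layer(torch.randn(3, 4))
```

There is no forward defined here, so we call the registered layer directly; in real code you would add a forward method as usual.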