Although @ptrblck’s answer works, it somewhat defeats the purpose of using a dataclass, part of which is not having to write the `__init__` function yourself.
So here are some requirements to make this work:

- The PyTorch module class (which is the dataclass itself) needs a `__hash__` function. The `__hash__` function is required in the `named_modules` function of `nn.Module` (see the sketch after this list).
- We need to call `super().__init__()` at some point.
- The dataclass should not be frozen, as the `__init__` function of `nn.Module` will try to set attributes. So you cannot use `@dataclass(frozen=True)` to get a `__hash__` function for your dataclass.
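To see why the `__hash__` requirement matters, here is a minimal sketch of the failure mode (the class name `Broken` is mine): a plain `@dataclass` has `eq=True` by default, which sets `__hash__` to `None`, and `named_modules` needs to put modules in a set.

```python
from dataclasses import dataclass

import torch.nn as nn

@dataclass  # eq=True by default, so __hash__ is set to None
class Broken(nn.Module):
    input_feats: int = 10

    def __post_init__(self):
        super().__init__()

net = Broken()
list(net.named_modules())  # TypeError: unhashable type: 'Broken'
```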
The only solution I found that is slightly better than @ptrblck’s answer, and that I think will work, is this:
```python
from dataclasses import dataclass

import torch.nn as nn

@dataclass(unsafe_hash=True)
class Net(nn.Module):
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        # nn.Module.__init__ must run before any layers are assigned
        super().__init__()
        self.layer = nn.Linear(self.input_feats, self.output_feats)
```
Notice the usage of `__post_init__` and the ugly hack of setting `unsafe_hash=True`.
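For completeness, a quick sanity check of the sketch above (hypothetical usage; the shapes are just the defaults from `Net`):

```python
import torch

net = Net()
# No TypeError here, because unsafe_hash=True gives Net a __hash__
print(list(net.named_modules()))
out = net.layer(torch.randn(4, net.input_feats))
print(out.shape)  # torch.Size([4, 20])
```

Note that `unsafe_hash=True` generates a field-based `__hash__`; it is called "unsafe" because the instance is mutable, so the hash can change if the fields do, but it is enough to satisfy `named_modules`.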