I think this can be remedied by the fact that `__new__` effectively behaves like a `__pre_init__`:
```python
import torch as tr
import torch.nn as nn
from dataclasses import dataclass


@dataclass
class DataclassModule(nn.Module):
    def __new__(cls, *args, **kwargs):
        # __new__ runs before the dataclass-generated __init__, so nn.Module's
        # internal state (_parameters, _modules, ...) already exists by the
        # time the dataclass fields are assigned.
        inst = super().__new__(cls)
        nn.Module.__init__(inst)
        return inst


# unsafe_hash=True keeps instances hashable; dataclasses would otherwise set
# __hash__ = None, and nn.Module traversal (e.g. named_modules) needs hashable modules.
@dataclass(unsafe_hash=True)
class Net(DataclassModule):
    other_layer: nn.Module
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        self.layer = nn.Linear(self.input_feats, self.output_feats)

    def forward(self, x):
        return self.layer(self.other_layer(x))


net = Net(other_layer=nn.Linear(10, 10))
assert net(tr.tensor([1.] * 10)).shape == (20,)
assert len(list(net.parameters())) == 4
```
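Without the `__new__` hook this can't work: the dataclass-generated `__init__` assigns each field via `__setattr__`, and `nn.Module.__setattr__` refuses to register a parameter or submodule before `nn.Module.__init__` has created `_parameters`/`_modules`. A minimal sketch of the failure (the `Broken` class is hypothetical, and the exact error message may differ across torch versions):

```python
@dataclass(unsafe_hash=True)
class Broken(nn.Module):  # no __new__ hook, so nn.Module.__init__ never runs first
    other_layer: nn.Module


try:
    Broken(nn.Linear(10, 10))
except AttributeError as e:
    # e.g. "cannot assign module before Module.__init__() call"
    print(e)
```

With the `DataclassModule` base in place, the pattern also composes with dataclass inheritance: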
```python
@dataclass(unsafe_hash=True)
class A(DataclassModule):
    x: int

    def __post_init__(self):
        self.layer1 = nn.Linear(self.x, self.x)


@dataclass(unsafe_hash=True)
class B(A):
    y: int

    def __post_init__(self):
        # dataclasses do not chain __post_init__ automatically,
        # so the subclass must call super() itself.
        super().__post_init__()
        self.layer2 = nn.Linear(self.y, self.y)


assert len(list(A(1).parameters())) == 2
assert len(list(B(1, 2).parameters())) == 4
```
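As a quick sanity check (a usage sketch, not part of the snippet above), the resulting objects seem to behave like ordinary modules with the usual `nn.Module` machinery:

```python
net = B(1, 2)
# submodules registered in __post_init__ show up in state_dict as usual
print(sorted(net.state_dict()))  # ['layer1.bias', 'layer1.weight', 'layer2.bias', 'layer2.weight']
opt = tr.optim.SGD(net.parameters(), lr=0.1)  # works like any other nn.Module
```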