How to use dataclass with PyTorch

I think this can be remedied by the fact that __new__ effectively behaves like a __pre_init__ hook: it runs before the dataclass-generated __init__, so calling nn.Module.__init__ there means the module machinery is already set up by the time __init__ and __post_init__ assign fields and submodules:

import torch as tr
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class DataclassModule(nn.Module):
    def __new__(cls, *args, **k):
        inst = super().__new__(cls)
        # Initialize nn.Module here, before the dataclass-generated __init__ runs,
        # so that fields and __post_init__ can freely assign parameters and submodules.
        nn.Module.__init__(inst)
        return inst

# unsafe_hash=True keeps instances hashable: the dataclass-generated __eq__ would
# otherwise set __hash__ to None, and nn.Module traversal (e.g. parameters(),
# named_modules()) tracks visited modules in a set.
@dataclass(unsafe_hash=True)
class Net(DataclassModule):
    other_layer: nn.Module
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        # nn.Module.__init__ has already run via __new__, so this registers a submodule.
        self.layer = nn.Linear(self.input_feats, self.output_feats)

    def forward(self, x):
        return self.layer(self.other_layer(x))

net = Net(other_layer=nn.Linear(10, 10))
assert net(tr.tensor([1.] * 10)).shape == (20,)
assert len(list(net.parameters())) == 4  # weight + bias of both Linear layers
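
For context, here is a minimal sketch of what goes wrong without the __new__ trick (BrokenNet is just an illustrative name, reusing the imports above): the dataclass-generated __init__ never calls nn.Module.__init__, so assigning a submodule in __post_init__ raises an AttributeError.

@dataclass
class BrokenNet(nn.Module):
    input_feats: int = 10

    def __post_init__(self):
        # Fails: nn.Module's _modules dict doesn't exist yet, because
        # nn.Module.__init__ was never called by the generated __init__.
        self.layer = nn.Linear(self.input_feats, self.input_feats)

try:
    BrokenNet()
except AttributeError as e:
    print(e)  # "cannot assign module before Module.__init__() call"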

Inheritance works too; a subclass just needs to call super().__post_init__() so the parent's layers are created as well:

@dataclass(unsafe_hash=True)
class A(DataclassModule):
    x: int
    def __post_init__(self):
        self.layer1 = nn.Linear(self.x, self.x)

@dataclass(unsafe_hash=True)
class B(A):
    y: int
    def __post_init__(self):
        # Chain up so A's layer1 is created as well.
        super().__post_init__()
        self.layer2 = nn.Linear(self.y, self.y)

assert len(list(A(1).parameters())) == 2  # layer1 weight + bias
assert len(list(B(1, 2).parameters())) == 4  # layer1 + layer2 weights and biases
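
As a quick sanity check (the optimizer and learning rate here are just illustrative), the dataclass-based module trains like any other nn.Module because its layers are registered normally:

opt = tr.optim.SGD(net.parameters(), lr=0.1)
loss = net(tr.ones(10)).sum()
loss.backward()  # gradients flow into both other_layer and layer
opt.step()
opt.zero_grad()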