How to use dataclass with PyTorch

@dataclass
class Net(nn.Module):
    ...

Something like this?

I don’t use dataclass myself. Have you run into any problems using it?

dir() gives an error:

import torch.nn as nn
from dataclasses import dataclass

@dataclass
class Net(nn.Module):
    pass

x = Net()

dir(x)

raises:

AttributeError: 'Net' object has no attribute '_parameters'

If I remove @dataclass, dir(x) lists the attributes as expected.
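A minimal, torch-free sketch of what is going on: when the class body does not define `__init__`, `@dataclass` generates one that only assigns the declared fields and never calls the parent's `__init__`. `Base` below is a hypothetical stand-in for `nn.Module`, whose `__init__` is what creates `_parameters`.

```python
from dataclasses import dataclass


class Base:
    """Hypothetical stand-in for nn.Module: its __init__ creates internal state."""

    def __init__(self):
        self._parameters = {}


@dataclass
class Child(Base):
    # No fields, so the generated __init__ takes no arguments and does nothing.
    pass


child = Child()
print(Child.__init__ is Base.__init__)  # False: @dataclass replaced the inherited __init__
print(hasattr(child, "_parameters"))    # False: Base.__init__ never ran
```

With the real `nn.Module`, the missing `_parameters` surfaces as the AttributeError above, because `Module.__dir__` reads `self._parameters`.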

It should work if you initialize the parent class:

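import torch.nn as nn
from dataclasses import dataclass

@dataclass
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()  # runs nn.Module.__init__, which creates _parameters etc.

x = Net()
dir(x)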

Although @ptrblck’s answer works, it kind of defeats the purpose of using a dataclass, which is partly to avoid writing the __init__ function yourself.
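A small sketch of that trade-off, again with a hypothetical `Base` standing in for `nn.Module`: once the class body defines `__init__`, `@dataclass` keeps it instead of generating one, so the declared fields are no longer accepted or assigned automatically.

```python
from dataclasses import dataclass, fields


class Base:
    """Hypothetical stand-in for nn.Module."""

    def __init__(self):
        self.initialized = True


@dataclass
class Net(Base):
    input_feats: int  # still recorded as a dataclass field...

    def __init__(self):  # ...but this hand-written __init__ is kept as-is
        super().__init__()


net = Net()
print([f.name for f in fields(Net)])  # ['input_feats']
print(net.initialized)                # True
print(hasattr(net, "input_feats"))    # False: nothing assigned the field

try:
    Net(10)  # no generated __init__(self, input_feats) exists
except TypeError as err:
    print("TypeError:", err)
```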

So here are some requirements to make this work:

  1. The PyTorch module class (which is the dataclass itself) needs a __hash__ method, because nn.Module.named_modules keeps the modules it has already visited in a set. By default, @dataclass (with eq=True and frozen=False) sets __hash__ to None, which makes instances unhashable.
  2. We need to call super().__init__() at some point.
  3. The dataclass should not be frozen, as nn.Module sets attributes on the instance (for example, when registering submodules). So you cannot use @dataclass(frozen=True) to get a __hash__ method for your dataclass.
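The hashing rules behind points 1 and 3 can be checked without PyTorch. This sketch only exercises the standard `dataclasses` module; the class names are made up for illustration.

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass  # eq=True, frozen=False: dataclass sets __hash__ to None
class Default:
    width: int = 10


@dataclass(unsafe_hash=True)  # __hash__ is computed from the field values
class UnsafeHash:
    width: int = 10


@dataclass(eq=False)  # no __eq__/__hash__ generated: identity semantics from object
class IdentityOnly:
    width: int = 10


@dataclass(frozen=True)  # hashable, but attribute assignment is rejected
class Frozen:
    width: int = 10


print(Default.__hash__)                          # None
print(hash(UnsafeHash()) == hash(UnsafeHash()))  # True: equal fields give equal hashes

a, b = IdentityOnly(), IdentityOnly()
print(a == b, len({a, b}))                       # False 2

try:
    Frozen().layer = "submodule"  # the kind of assignment nn.Module relies on
except FrozenInstanceError as err:
    print("FrozenInstanceError:", err)
```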

The only solution I have found that I think will work, and that is slightly better than @ptrblck’s answer, is this:

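import torch.nn as nn
from dataclasses import dataclass

@dataclass(unsafe_hash=True)
class Net(nn.Module):
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        # Runs after the generated __init__ has assigned the fields.
        super().__init__()
        self.layer = nn.Linear(self.input_feats, self.output_feats)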

Notice the use of __post_init__ and the ugly hack of setting unsafe_hash=True.
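A quick check of why `unsafe_hash=True` is needed, assuming a reasonably recent PyTorch: with the default `@dataclass`, `__hash__` is `None`, so anything that walks the module tree (here `parameters()`) fails with a `TypeError`. `PlainNet` and `HashedNet` are illustrative names.

```python
import torch
import torch.nn as nn
from dataclasses import dataclass


@dataclass  # default eq=True, so __hash__ is set to None
class PlainNet(nn.Module):
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        super().__init__()
        self.layer = nn.Linear(self.input_feats, self.output_feats)


@dataclass(unsafe_hash=True)  # hash derived from (input_feats, output_feats)
class HashedNet(nn.Module):
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        super().__init__()
        self.layer = nn.Linear(self.input_feats, self.output_feats)


try:
    list(PlainNet().parameters())  # named_modules() puts modules into a set
except TypeError as err:
    print("PlainNet:", err)

net = HashedNet()
print(len(list(net.parameters())))          # 2: weight and bias
print(net.layer(torch.randn(3, 10)).shape)  # torch.Size([3, 20])
```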

I wanted to use your solution in my project, but I ran into an issue when one of the fields is an nn.Module:

import torch
import torch.nn as nn
from dataclasses import dataclass


class Evaluators(nn.Module):
    def __init__(self):
        super(Evaluators, self).__init__()
        self.linear = nn.Linear(1, 1)

@dataclass(unsafe_hash=True)
class Net(nn.Module):
    evaluator: Evaluators
    def __post_init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

evaluators = Evaluators()
net = Net(evaluators)

fails with:

test_dataclass.py:18 (test_dataclass)
def test_dataclass():
        evaluators = Evaluators()
>       net = Net(evaluators)

test_dataclass.py:21: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
<string>:3: in __init__
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <[ModuleAttributeError("'Net' object has no attribute 'evaluator'") raised in repr()] Net object at 0x17d02449370>
name = 'evaluator'
value = Evaluators(
  (linear): Linear(in_features=1, out_features=1, bias=True)
)

    def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
        def remove_from(*dicts_or_sets):
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)
    
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
>                   raise AttributeError(
                        "cannot assign module before Module.__init__() call")
E                   AttributeError: cannot assign module before Module.__init__() call

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py:807: AttributeError

Is there a way to solve this, or does it just mean that I won’t be able to use dataclass?

I had not considered fields of type nn.Module, so the solution I proposed will not work in your setting. I think a solution is possible, but it requires a bit more hacking of nn.Module or dataclass.
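To see where the order goes wrong in the Evaluators example above, here is a torch-free sketch. `Base` is a hypothetical stand-in for `nn.Module` that logs attribute assignments and its own initialization. The dataclass-generated `__init__` assigns every field first and only then calls `__post_init__`, so `super().__init__()` arrives too late for an `nn.Module`-valued field, which `Module.__setattr__` refuses to register before `Module.__init__()` has run.

```python
from dataclasses import dataclass

events = []


class Base:
    """Hypothetical stand-in for nn.Module that records what happens, in order."""

    def __init__(self):
        events.append("Base.__init__")

    def __setattr__(self, name, value):
        events.append(f"set {name}")
        super().__setattr__(name, value)


@dataclass
class Net(Base):
    evaluator: object

    def __post_init__(self):
        super().__init__()


Net(evaluator="stand-in for Evaluators()")
print(events)  # ['set evaluator', 'Base.__init__']
```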

@yassersouri Please raise your use case in the GitHub issue “[discussion] Remove the need of mandatory super() module call” (Issue #61686, pytorch/pytorch). With more comments, the core team might consider it more worthwhile.

I think this can be remedied by using __new__, which runs before the dataclass-generated __init__ and therefore effectively behaves like a __pre_init__ hook:

import torch as tr
import torch.nn as nn
from dataclasses import dataclass

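@dataclass
class DataclassModule(nn.Module):
    def __new__(cls, *args, **kwargs):
        # __new__ runs before the dataclass-generated __init__, so nn.Module's
        # internal state exists by the time the fields are assigned.
        inst = super().__new__(cls)
        nn.Module.__init__(inst)
        return inst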

@dataclass(unsafe_hash=True)
class Net(DataclassModule):
    other_layer: nn.Module
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        self.layer = nn.Linear(self.input_feats, self.output_feats)

    def forward(self, x):
        return self.layer(self.other_layer(x))

net = Net(other_layer=nn.Linear(10, 10))
assert net(tr.tensor([1.]*10)).shape == (20,)
assert len(list(net.parameters())) == 4

@dataclass(unsafe_hash=True)
class A(DataclassModule):
    x: int
    def __post_init__(self):
        self.layer1 = nn.Linear(self.x, self.x)

@dataclass(unsafe_hash=True)
class B(A):
    y: int
    def __post_init__(self):
        super().__post_init__()
        self.layer2 = nn.Linear(self.y, self.y)

assert len(list(A(1).parameters())) == 2
assert len(list(B(1, 2).parameters())) == 4
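The same kind of logging sketch shows why the `__new__` trick helps: `__new__` runs before the generated `__init__`, so the base initializer has already run by the time fields are assigned. `Base` and `PreInit` are hypothetical stand-ins for `nn.Module` and `DataclassModule`; `PreInit` only reproduces the `__new__` part.

```python
from dataclasses import dataclass

events = []


class Base:
    """Hypothetical stand-in for nn.Module that records what happens, in order."""

    def __init__(self):
        events.append("Base.__init__")

    def __setattr__(self, name, value):
        events.append(f"set {name}")
        super().__setattr__(name, value)


class PreInit(Base):
    # Mirrors DataclassModule: initialize the base class before any field is assigned.
    def __new__(cls, *args, **kwargs):
        inst = super().__new__(cls)  # object.__new__; constructor args are not forwarded
        Base.__init__(inst)
        return inst


@dataclass
class Net(PreInit):
    evaluator: object


Net(evaluator="stand-in for Evaluators()")
print(events)  # ['Base.__init__', 'set evaluator']
```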

I have had some issues using this solution with nested DataclassModule-derived classes. Some submodules were not transferring their weights to the GPU when calling model.cuda(). Digging a bit deeper, I found that the parameters of some submodules were not registered as parameters of the parent modules, which happened when there were many instances of the same dataclass module in the model.

Using @dataclass(eq=False) instead of @dataclass(unsafe_hash=True) seems to resolve this. Here is a link to a related discussion:

https://stackoverflow.com/questions/57291307/pytorch-module-with-attrs-cannot-get-parameter-list
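A sketch that reproduces the symptom, under two assumptions: a recent PyTorch, and a `DataclassModule` base declared with `@dataclass(eq=False)`. The second matters because a plain `@dataclass` on the base sets its `__hash__` to `None`, and subclasses declared with `eq=False` inherit that, so they would be unhashable. With `unsafe_hash=True`, two `Block`s with the same field values compare equal and hash equally, so the set that `named_modules()` uses to skip already-visited modules treats the second one as a duplicate and never visits its parameters. `Block`, `Model` and `count_params` are illustrative names.

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass(eq=False)  # eq=False here too, so __hash__ is not set to None
class DataclassModule(nn.Module):
    def __new__(cls, *args, **kwargs):
        inst = super().__new__(cls)
        nn.Module.__init__(inst)
        return inst


def count_params(**dataclass_kwargs):
    """Build a model holding two identical blocks and count the parameters PyTorch sees."""

    @dataclass(**dataclass_kwargs)
    class Block(DataclassModule):
        width: int

        def __post_init__(self):
            self.lin = nn.Linear(self.width, self.width)

    @dataclass(**dataclass_kwargs)
    class Model(DataclassModule):
        first: nn.Module
        second: nn.Module

    model = Model(Block(4), Block(4))  # same field values, distinct instances
    return len(list(model.parameters()))


print(count_params(unsafe_hash=True))  # 2: the second Block is treated as a duplicate
print(count_params(eq=False))          # 4: both Blocks are visited
```

The same deduplication applies when `children()` is walked, which would be consistent with some submodules being skipped by `model.cuda()`.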