How to use dataclass with PyTorch

@dataclass
class Net(nn.Module):
    ...

Something like this?

I don’t use dataclasses myself. Did you run into any problems using one?

Calling dir() on an instance fails. This code:

import torch.nn as nn
from dataclasses import dataclass

@dataclass
class Net(nn.Module):
    pass

x = Net()
dir(x)

raises the following error:

AttributeError: 'Net' object has no attribute '_parameters'

If I remove @dataclass, dir(x) lists the attributes as expected.
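A minimal sketch of why this happens, assuming a recent PyTorch version: the __init__ that @dataclass generates is placed on Net itself, so it shadows the inherited nn.Module.__init__ and never calls it. The bookkeeping dicts such as _parameters are therefore never created, and nn.Module.__dir__ fails when it tries to read them.

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class Net(nn.Module):
    pass


x = Net()

# The generated __init__ is defined on Net and shadows nn.Module.__init__.
print("__init__" in Net.__dict__)  # True

# nn.Module.__init__ never ran, so its internal dicts were never created.
print("_parameters" in x.__dict__)  # False

# nn.Module.__dir__ reads self._parameters; the lookup falls through to
# nn.Module.__getattr__, which raises AttributeError.
try:
    dir(x)
except AttributeError as err:
    print(err)  # 'Net' object has no attribute '_parameters'
```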

It should work if you initialize the parent class:

import torch.nn as nn
from dataclasses import dataclass

@dataclass
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

x = Net()
dir(x)

Although @ptrblck’s answer works, it somewhat defeats the purpose of using a dataclass, part of which is not having to write the __init__ function yourself.

So here are some requirements to make this work:

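  1. The PyTorch module class (which is the dataclass itself) needs a __hash__ method, because nn.Module.named_modules (which parameters(), modules(), etc. rely on) keeps track of visited modules in a set. A plain @dataclass sets __hash__ to None, since by default it generates __eq__ and is not frozen.
  2. super().__init__() must be called before any submodule or parameter is assigned to the instance.
  3. The dataclass must not be frozen, because nn.Module.__init__ (and any layer assignment in your own code) sets attributes on the instance. So you cannot use @dataclass(frozen=True) to get a __hash__ method for your dataclass.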
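To illustrate requirement 1, here is a small sketch assuming a recent PyTorch version. It reuses the hand-written __init__ from @ptrblck’s answer (dataclass does not overwrite an __init__ defined in the class body) and adds a hypothetical layer. dir() now works, but because the default eq=True makes dataclass set __hash__ to None, anything that goes through named_modules(), such as parameters(), fails.

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass  # eq=True by default, so dataclass sets Net.__hash__ = None
class Net(nn.Module):
    def __init__(self):  # defined in the class body, so dataclass keeps it
        super().__init__()
        self.layer = nn.Linear(2, 2)  # hypothetical layer for illustration


net = Net()
print(Net.__hash__ is None)  # True

# dir() works, because nn.Module.__init__ has run...
print("layer" in dir(net))  # True

# ...but parameters() walks named_modules(), which checks membership in a
# set, so hashing the instance raises TypeError.
try:
    list(net.parameters())
except TypeError as err:
    print(err)  # unhashable type: 'Net'
```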

The only solution I found that I think will work, and that is slightly better than @ptrblck’s answer, is this:

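from dataclasses import dataclass
import torch.nn as nn

@dataclass(unsafe_hash=True)
class Net(nn.Module):
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        # The generated __init__ has already set the fields; now initialize nn.Module
        super().__init__()
        self.layer = nn.Linear(self.input_feats, self.output_feats)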

Note the use of __post_init__ and the ugly hack of setting unsafe_hash=True.
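One caveat worth checking (this goes beyond the original answer): with unsafe_hash=True, both __hash__ and __eq__ are computed from the field values. Two sibling modules built with the same arguments therefore compare equal, named_modules() treats the second one as already visited, and its parameters silently drop out of parameters(). A possible alternative, assuming you do not need field-based equality, is @dataclass(eq=False), which leaves the identity-based __eq__ and __hash__ inherited through nn.Module untouched. The class names HashedNet, IdentityNet and Container below are hypothetical, chosen only for this comparison.

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass(unsafe_hash=True)  # hash/eq derived from (input_feats, output_feats)
class HashedNet(nn.Module):
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        super().__init__()
        self.layer = nn.Linear(self.input_feats, self.output_feats)


@dataclass(eq=False)  # no generated __eq__, so identity-based hashing is kept
class IdentityNet(nn.Module):
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        super().__init__()
        self.layer = nn.Linear(self.input_feats, self.output_feats)


class Container(nn.Module):
    """Hypothetical parent holding two identically configured children."""

    def __init__(self, child_cls):
        super().__init__()
        self.a = child_cls()
        self.b = child_cls()


# Each child has a weight and a bias, so 4 parameters are expected in total.
print(len(list(Container(HashedNet).parameters())))    # 2: child b is skipped
print(len(list(Container(IdentityNet).parameters())))  # 4
```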

I wanted to use your solution in my project, but I ran into an issue when one of the fields is an nn.Module:

import torch
import torch.nn as nn
from dataclasses import dataclass


class Evaluators(nn.Module):
    def __init__(self):
        super(Evaluators, self).__init__()
        self.linear = nn.Linear(1, 1)

@dataclass(unsafe_hash=True)
class Net(nn.Module):
    evaluator: Evaluators
    def __post_init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

evaluators = Evaluators()
net = Net(evaluators)

fails with:

test_dataclass.py:18 (test_dataclass)
def test_dataclass():
        evaluators = Evaluators()
>       net = Net(evaluators)

test_dataclass.py:21: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
<string>:3: in __init__
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <[ModuleAttributeError("'Net' object has no attribute 'evaluator'") raised in repr()] Net object at 0x17d02449370>
name = 'evaluator'
value = Evaluators(
  (linear): Linear(in_features=1, out_features=1, bias=True)
)

    def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
        def remove_from(*dicts_or_sets):
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)
    
        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
>                   raise AttributeError(
                        "cannot assign module before Module.__init__() call")
E                   AttributeError: cannot assign module before Module.__init__() call

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py:807: AttributeError

Is there a way to solve this, or does it just mean that I won’t be able to use a dataclass?

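I had not considered member variables of type nn.Module, so the solution I proposed will not work in your setting: the generated __init__ assigns the fields before __post_init__ gets a chance to call super().__init__().
I think a solution is possible, but it requires a bit more hacking into nn.Module or dataclasses.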
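One possible workaround, sketched here as an assumption rather than a solution from this thread: declare the submodule as a dataclasses.InitVar. The generated __init__ then forwards it to __post_init__ instead of assigning it, so nn.Module.__init__ can run before the submodule is registered. The trade-off is that the submodule is no longer a regular dataclass field (it is absent from fields(), the generated repr and eq). eq=False is used so that instances keep identity-based hashing; the hidden field is hypothetical and only shows that ordinary fields still work.

```python
from dataclasses import InitVar, dataclass, fields

import torch.nn as nn


class Evaluators(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)


@dataclass(eq=False)
class Net(nn.Module):
    # InitVar: accepted by the generated __init__ and passed to __post_init__,
    # but never assigned as an attribute by the dataclass machinery.
    evaluator: InitVar[Evaluators]
    hidden: int = 1  # hypothetical plain field, assigned before __post_init__

    def __post_init__(self, evaluator: Evaluators):
        super().__init__()          # creates _parameters, _modules, ...
        self.evaluator = evaluator  # now registered as a submodule
        self.linear = nn.Linear(1, self.hidden)


net = Net(Evaluators())
print([f.name for f in fields(net)])       # ['hidden']
print(sorted(dict(net.named_children())))  # ['evaluator', 'linear']
print(len(list(net.parameters())))         # 4
```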

@yassersouri Please raise your use case in the issue “[discussion] Remove the need of mandatory super() module call” (pytorch/pytorch issue #61686 on GitHub). With more comments, the core team might consider it worth addressing.