@dataclass
class Net(nn.Module):
    …

Something like this?
I don’t use dataclass myself. Do you encounter any problem using it?
dir gives an error
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class Net(nn.Module):
    pass

x = Net()
dir(x)

gives the error:

AttributeError: 'Net' object has no attribute '_parameters'
If I remove @dataclass, then it lists the attributes.
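For reference, here is a sketch of why this seems to fail (my understanding, not an official explanation): @dataclass generates an __init__ for Net, which shadows nn.Module.__init__, so the internal registries that nn.Module's attribute machinery relies on are never created:

import torch.nn as nn
from dataclasses import dataclass

@dataclass
class Net(nn.Module):
    pass

x = Net()
# The generated __init__ never called nn.Module.__init__, so the registry
# dicts are missing; dir(x) then reads self._parameters and raises.
print('_parameters' in vars(x))  # False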
It should work if you initialize the parent class:

@dataclass
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

x = Net()
dir(x)
Although @ptrblck’s answer works, it kind of defeats the purpose of using a dataclass, which is partly not writing the __init__ function yourself.

So here are some requirements to make this work:

1. The model needs a __hash__ function, because __hash__ is required by the named_modules function of nn.Module.
2. You need to call super().__init__() at some point.
3. The __init__ function of nn.Module will try to set attributes, so you cannot use @dataclass(frozen=True) to get a __hash__ function for your dataclass.

The only solution I found that is slightly better than @ptrblck’s answer, and which I think will work, is this:
@dataclass(unsafe_hash=True)
class Net(nn.Module):
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        super().__init__()
        self.layer = nn.Linear(self.input_feats, self.output_feats)
Notice the usage of __post_init__ and the ugly hack of setting unsafe_hash=True.
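To illustrate the behavior (a quick check of my own, assuming the Net above): the generated __init__ first assigns the two int fields, then calls __post_init__, which runs nn.Module.__init__ before any submodule is assigned:

net = Net()  # uses the defaults input_feats=10, output_feats=20
print(net)   # the Linear submodule is registered as expected
assert len(list(net.parameters())) == 2  # weight and bias of self.layer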
I wanted to use your solution for my project; however, I ran into an issue when I have an nn.Module as an argument:
import torch
import torch.nn as nn
from dataclasses import dataclass

class Evaluators(nn.Module):
    def __init__(self):
        super(Evaluators, self).__init__()
        self.linear = nn.Linear(1, 1)

@dataclass(unsafe_hash=True)
class Net(nn.Module):
    evaluator: Evaluators

    def __post_init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

evaluators = Evaluators()
net = Net(evaluators)
returns:
test_dataclass.py:18 (test_dataclass)
def test_dataclass():
        evaluators = Evaluators()
>       net = Net(evaluators)

test_dataclass.py:21:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
<string>:3: in __init__
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <[ModuleAttributeError("'Net' object has no attribute 'evaluator'") raised in repr()] Net object at 0x17d02449370>
name = 'evaluator'
value = Evaluators(
  (linear): Linear(in_features=1, out_features=1, bias=True)
)

    def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
        def remove_from(*dicts_or_sets):
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
>                   raise AttributeError(
                        "cannot assign module before Module.__init__() call")
E                   AttributeError: cannot assign module before Module.__init__() call

C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py:807: AttributeError
Is there a way to solve it, or does it just mean that I won’t be able to use dataclass?
I had not considered member variables of type nn.Module, so the solution I proposed will not work in your setting. I think a solution is possible, but it requires a bit more hacking into nn.Module or dataclass.
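For what it’s worth, the ordering explains the failure: the dataclass-generated __init__ assigns every field before it calls __post_init__, so an nn.Module-typed field reaches Module.__setattr__ while _modules does not exist yet. A minimal dataclass-free reproduction of that same ordering (my sketch):

import torch.nn as nn

class Manual(nn.Module):
    def __init__(self, evaluator):
        # same order as the dataclass-generated __init__: fields first ...
        self.evaluator = evaluator  # Module.__setattr__ fires here
        # ... and only afterwards would __post_init__ call super().__init__()
        super().__init__()

try:
    Manual(nn.Linear(1, 1))
except AttributeError as e:
    print(e)  # cannot assign module before Module.__init__() call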
@yassersouri Please raise your use case in the issue [discussion] Remove the need of mandatory super() module call · Issue #61686 · pytorch/pytorch · GitHub. With more comments, maybe the core team will consider it more worthy.
I think this can be remedied by the fact that __new__ effectively behaves like a __pre_init__:
import torch as tr
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class DataclassModule(nn.Module):
    def __new__(cls, *args, **k):
        inst = super().__new__(cls)
        nn.Module.__init__(inst)
        return inst

@dataclass(unsafe_hash=True)
class Net(DataclassModule):
    other_layer: nn.Module
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        self.layer = nn.Linear(self.input_feats, self.output_feats)

    def forward(self, x):
        return self.layer(self.other_layer(x))

net = Net(other_layer=nn.Linear(10, 10))
assert net(tr.tensor([1.] * 10)).shape == (20,)
assert len(list(net.parameters())) == 4
@dataclass(unsafe_hash=True)
class A(DataclassModule):
    x: int

    def __post_init__(self):
        self.layer1 = nn.Linear(self.x, self.x)

@dataclass(unsafe_hash=True)
class B(A):
    y: int

    def __post_init__(self):
        super().__post_init__()
        self.layer2 = nn.Linear(self.y, self.y)

assert len(list(A(1).parameters())) == 2
assert len(list(B(1, 2).parameters())) == 4
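A note on why this works (my reading of it): __new__ runs before the dataclass-generated __init__, so nn.Module.__init__ has already created _parameters/_buffers/_modules by the time the field assignments happen. A quick sanity check of that ordering, using a hypothetical Probe class:

@dataclass(unsafe_hash=True)
class Probe(DataclassModule):
    inner: nn.Module = None

    def __setattr__(self, name, value):
        if name == "inner":
            # the generated __init__ assigns this field only after
            # DataclassModule.__new__ has already run nn.Module.__init__
            assert "_modules" in self.__dict__
        super().__setattr__(name, value)

Probe(inner=nn.Linear(2, 2))  # no "cannot assign module" error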
I have had some issues using this solution with nested DataclassModule-derived classes. What I observed was that some submodules were not transferring their weights to the GPU when calling model.cuda(). Digging into it a bit deeper, the parameters of some submodules were not registered as parameters in the parent modules, which occurred when there were many instances of the same dataclass module in the model.
Using @dataclass(eq=False) instead of @dataclass(unsafe_hash=True) seems to resolve this. Here is a link to a related discussion:
https://stackoverflow.com/questions/57291307/pytorch-module-with-attrs-cannot-get-parameter-list
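To make the failure mode concrete, here is a sketch reusing the DataclassModule base from the earlier post (class names are mine): with unsafe_hash=True the dataclass still generates a field-based __eq__, so two distinct submodules with equal field values compare and hash equal, and the memo set inside named_modules() deduplicates one of them away. With eq=False, the identity-based __eq__ and __hash__ from object are kept, so every instance stays distinct:

import torch.nn as nn
from dataclasses import dataclass

@dataclass(eq=False)  # keep object's identity-based __eq__/__hash__
class Block(DataclassModule):
    width: int = 8

    def __post_init__(self):
        self.layer = nn.Linear(self.width, self.width)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = Block(8)  # two distinct submodules with identical fields
        self.b = Block(8)

m = Model()
# with @dataclass(unsafe_hash=True) on Block, a == b would hold and only
# one Block's parameters would be enumerated (2 instead of 4)
assert len(list(m.parameters())) == 4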