When I decorate a submodule class A (which inherits from nn.Module) with @dataclass, the main model class B (which also inherits from nn.Module and holds A as a submodule) raises an error when calling the to(device) method. The error message is TypeError: unhashable type: 'A'. The problem can be reproduced with the following minimal example. Are there any ways to solve this issue?
import torch
import torch.nn as nn
from dataclasses import dataclass


@dataclass
class A(nn.Module):
    a: int
    b: int
    c: int

    def forward(self, x):
        return self.a * x + self.b * x + self.c


class B(nn.Module):
    def __init__(self, a, b, c):
        super().__init__()
        self.a = a
        self.b = b
        self.c = c
        self.sub_module = A(a, b, c)

    def forward(self, x):
        return self.a * x + self.b * x + self.c


model = B(1, 2, 3).to(torch.device('cuda'))  # raises TypeError: unhashable type: 'A'
x = torch.tensor([1.0])
print(model(x))
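One workaround that seems to fix this particular repro: pass eq=False to the dataclass decorator, so nn.Module's identity-based __hash__ is kept (the auto-generated __eq__ would otherwise set __hash__ to None, which is what breaks the module traversal inside .to()), and call super().__init__() from __post_init__, since the generated __init__ never does so. A minimal sketch, assuming the dataclass fields are plain Python values (Parameter or Module fields would need nn.Module.__init__ to run before they are assigned):

import torch
import torch.nn as nn
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps nn.Module's default, identity-based __hash__
class A(nn.Module):
    a: int
    b: int
    c: int

    def __post_init__(self):
        # The dataclass-generated __init__ never calls nn.Module.__init__,
        # so set up the module bookkeeping (_parameters, _modules, ...) here.
        super().__init__()

    def forward(self, x):
        return self.a * x + self.b * x + self.c

With this version of A, the B(1, 2, 3).to(torch.device('cuda')) call above should no longer raise the unhashable-type error.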
I also use JAX, where dataclasses are very convenient and can be registered as pytrees to work with autograd. I just wonder whether PyTorch has adapted to dataclasses.
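For reference, this is roughly the JAX pattern I mean: a plain dataclass registered as a pytree so jax.grad can traverse its fields. A small sketch; the Linear class and its fields here are just made-up names for illustration:

import jax
import jax.numpy as jnp
from dataclasses import dataclass


@dataclass
class Linear:
    w: jnp.ndarray
    b: jnp.ndarray


# Register the dataclass so JAX treats its fields as pytree leaves.
jax.tree_util.register_pytree_node(
    Linear,
    lambda m: ((m.w, m.b), None),       # flatten: (children, aux_data)
    lambda aux, leaves: Linear(*leaves),  # unflatten
)


def loss(params, x, y):
    pred = x @ params.w + params.b
    return jnp.mean((pred - y) ** 2)


params = Linear(jnp.ones((3, 1)), jnp.zeros((1,)))
grads = jax.grad(loss)(params, jnp.ones((4, 3)), jnp.zeros((4, 1)))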
Autograd doesn’t have extensive pytree support like JAX, yes, and part of that is for backward-compatibility reasons. It is easy to write wrappers that do the pytree handling, though.
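To illustrate the wrapper idea: a small hand-rolled flatten/unflatten pair can turn a dataclass of tensors into a flat list of leaves for torch.autograd.grad and back. A sketch only; Params, flatten, unflatten, and loss_fn below are hypothetical names, not a PyTorch API:

import dataclasses
import torch


@dataclasses.dataclass
class Params:
    w: torch.Tensor
    b: torch.Tensor


def flatten(p):
    # Split the dataclass into leaf tensors plus the field names needed to rebuild it.
    names = [f.name for f in dataclasses.fields(p)]
    return [getattr(p, n) for n in names], names


def unflatten(leaves, names):
    return Params(**dict(zip(names, leaves)))


def loss_fn(leaves, names, x, y):
    p = unflatten(leaves, names)
    pred = p.w * x + p.b
    return ((pred - y) ** 2).mean()


params = Params(torch.randn(1, requires_grad=True), torch.zeros(1, requires_grad=True))
leaves, names = flatten(params)
x, y = torch.randn(8), torch.randn(8)
grads = torch.autograd.grad(loss_fn(leaves, names, x, y), leaves)  # one gradient per leaf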
I just wonder whether PyTorch has adapted to dataclasses.
For the purpose of autograd, I feel like I still don’t completely understand why making a module a dataclass is beneficial, because you wouldn’t be passing that module instance itself to autograd, right?
Yes, as you said. I looked into it further, and it’s true that you shouldn’t pass a dataclass-based module directly to autograd; even in JAX, the dataclass is just there to better manage the model parameters. I had confused these concepts. Thank you very much for your answer, it helped me understand the difference.