Conflict between `dataclass` and `nn.Module`

When I decorate a submodule class `A` (a subclass of `nn.Module`) with `@dataclass`, calling `.to(device)` on the main model class `B` (also an `nn.Module`, holding an `A` instance as a submodule) raises `TypeError: unhashable type: 'A'`. The problem can be reproduced with the following minimal example. Are there any ways to solve this issue?

```python
from dataclasses import dataclass

import torch
from torch import nn

@dataclass
class A(nn.Module):
    a: int
    b: int
    c: int

    def forward(self, x):
        return self.a * x + self.b * x + self.c

class B(nn.Module):
    def __init__(self, a, b, c):
        super().__init__()
        self.a = a
        self.b = b
        self.c = c
        self.sub_module = A(a, b, c)  # the dataclass-decorated submodule

    def forward(self, x):
        return self.a * x + self.b * x + self.c

# Raises TypeError: unhashable type: 'A'
model = B(1, 2, 3).to(torch.device('cuda'))
x = torch.tensor([1.0])
print(model(x))
```
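The likely cause: with the default `eq=True` (and `frozen=False`), `@dataclass` sets `__hash__` to `None`, replacing the identity-based hash that `nn.Module` otherwise inherits from `object`. `B.to()` recurses into its children, and in current PyTorch versions `Module.named_children()` deduplicates submodules with a `set` (an implementation detail), which requires hashing `A`. A minimal check, using a one-field `A` for brevity:

```python
from dataclasses import dataclass

from torch import nn

@dataclass
class A(nn.Module):
    a: int

# eq=True (the default) without frozen=True makes @dataclass set __hash__ = None
print(A.__hash__)  # None

try:
    hash(A(1))  # putting A into a set (as children() does) needs this hash
except TypeError as err:
    print(err)
```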

You can get this to work by doing both of the following:

  • passing `unsafe_hash=True` to the `dataclass` decorator, or defining `__hash__` yourself
  • overriding the `__init__` that `dataclass` generates so that it also calls `super().__init__()`
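A sketch of those two changes, assuming the fields are plain Python values. Instead of rewriting `__init__` outright, it uses the `__post_init__` hook, which the dataclass-generated `__init__` calls after assigning the fields. Two caveats: because the fields are assigned before `nn.Module.__init__` runs, this only works for plain fields (assigning an `nn.Module` or `Parameter` field that early raises an `AttributeError`); and `unsafe_hash=True` hashes and compares by field values, so two `A` instances with equal fields compare equal and one may be skipped when `nn.Module` walks its children. `@dataclass(eq=False)`, which keeps the inherited identity-based `__eq__`/`__hash__`, avoids that second issue.

```python
from dataclasses import dataclass

import torch
from torch import nn

# unsafe_hash=True makes @dataclass generate __hash__ from the field values
# instead of setting it to None, so A can be stored in a set.
@dataclass(unsafe_hash=True)
class A(nn.Module):
    a: int
    b: int
    c: int

    def __post_init__(self):
        # Called at the end of the generated __init__; runs nn.Module.__init__
        # so that _parameters, _modules, etc. exist on the instance.
        super().__init__()

    def forward(self, x):
        return self.a * x + self.b * x + self.c

class B(nn.Module):
    def __init__(self, a, b, c):
        super().__init__()
        self.a = a
        self.b = b
        self.c = c
        self.sub_module = A(a, b, c)  # registered as a child module

    def forward(self, x):
        return self.a * x + self.b * x + self.c

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = B(1, 2, 3).to(device)  # no longer raises
print(model(torch.tensor([1.0], device=device)))
```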

It’s not super clear why you’d use a dataclass at that point, though, since you’d have to override much of its default behavior (`__hash__` and `__init__`) anyway.

Also, you might not need `A` to inherit from `nn.Module` at all, since it appears to be stateless here (it holds only plain ints, no parameters or buffers).
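For illustration, a sketch of that alternative, assuming `A` really only needs its three ints: `A` becomes a plain dataclass with a `__call__` method (a naming choice made here, not something from the thread), and `B.forward` is changed to delegate to it. Since `A` is no longer an `nn.Module`, `B` stores it as an ordinary attribute rather than a registered submodule, so `.to()` never visits it; that is fine because there are no tensors inside it to move.

```python
from dataclasses import dataclass

import torch
from torch import nn

@dataclass
class A:  # plain dataclass: no nn.Module, so no hash/init conflict
    a: int
    b: int
    c: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.a * x + self.b * x + self.c

class B(nn.Module):
    def __init__(self, a, b, c):
        super().__init__()
        self.sub_module = A(a, b, c)  # ordinary attribute, not a registered child

    def forward(self, x):
        return self.sub_module(x)

model = B(1, 2, 3).to("cpu")
print(model(torch.tensor([1.0])))
```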

I also use JAX, where dataclasses are very convenient: they can be registered as pytrees so that autograd can work with them directly. I just wonder if PyTorch has adapted to dataclasses.

I see, thanks for the context.

Yes, autograd doesn’t have extensive pytree support like JAX does, and that is partly for backward-compatibility (BC) reasons. It is easy to write wrappers that do the pytree handling, though.
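As a rough illustration of such a wrapper (not an official PyTorch API), the sketch below flattens a plain, non-module dataclass of tensors with `dataclasses.fields`, calls `torch.autograd.grad` on the leaves, and packs the gradients back into the same dataclass type, similar in spirit to `jax.grad` over a registered pytree. `Params`, `grad_wrt_dataclass`, and `loss_fn` are hypothetical names, and every field is assumed to be a floating-point tensor.

```python
import dataclasses

import torch

@dataclasses.dataclass
class Params:  # hypothetical parameter container (a plain dataclass, not an nn.Module)
    w: torch.Tensor
    b: torch.Tensor

def grad_wrt_dataclass(loss_fn, params, *args):
    """Gradient of loss_fn(params, *args) w.r.t. params, as the same dataclass type.

    Assumes every field of `params` is a floating-point tensor.
    """
    names = [f.name for f in dataclasses.fields(params)]
    # "Flatten": detached copies of each field that track gradients
    leaves = [getattr(params, n).detach().requires_grad_(True) for n in names]
    # "Unflatten": rebuild the dataclass around the tracked leaves
    tracked = dataclasses.replace(params, **dict(zip(names, leaves)))
    loss = loss_fn(tracked, *args)
    grads = torch.autograd.grad(loss, leaves)
    # Pack the gradients back into the original structure
    return dataclasses.replace(params, **dict(zip(names, grads)))

def loss_fn(p, x):
    return (p.w * x + p.b).sum()

params = Params(w=torch.tensor(2.0), b=torch.tensor(1.0))
g = grad_wrt_dataclass(loss_fn, params, torch.tensor([1.0, 2.0, 3.0]))
print(g)  # Params(w=tensor(6.), b=tensor(3.))
```

Here the loss is `6 * w + 3 * b`, so the gradients are 6 and 3.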

> I just wonder if PyTorch has adapted to dataclasses.

For the purpose of autograd, I still don’t completely understand the benefit of making a module a dataclass, because you wouldn’t be passing the module instance itself to autograd, right?
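To make that concrete, a sketch assuming PyTorch 2.x with `torch.func`: what gets differentiated is a collection of tensors, not the module object. The module's parameters are pulled out into a plain dict (a structure PyTorch's pytree utilities already handle), and `torch.func.grad` differentiates a loss with respect to that dict, much like passing a parameter pytree to `jax.grad`. `nn.Linear` and the squared-error loss are placeholders.

```python
import torch
from torch import nn
from torch.func import functional_call, grad

model = nn.Linear(3, 1)
# The parameters, not the module, form the structure that gets differentiated.
params = {name: p.detach() for name, p in model.named_parameters()}

def loss_fn(params, x, y):
    # Run the module's forward with the given parameters substituted in
    pred = functional_call(model, params, (x,))
    return ((pred - y) ** 2).mean()

x = torch.randn(4, 3)
y = torch.randn(4, 1)
grads = grad(loss_fn)(params, x, y)  # dict with the same keys as params
print({k: tuple(v.shape) for k, v in grads.items()})
```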

Yes, as you said. I looked into it, and it’s true that you shouldn’t pass a dataclass-based module directly to autograd, even in JAX; the dataclass is just a convenient way to manage the model parameters. I had mixed up these concepts. Thank you very much for your answer, it helped me understand the difference.