How to implement a mixin in PyTorch

How should an `EncoderDecoder` class that mixes in both `Encoder` and `Decoder` be implemented?

In the following snippet, `self.enc` is not registered as a submodule.

class Encoder(nn.Module):
    """Module that owns a single 2 -> 2 linear layer registered as ``enc``."""

    def __init__(self):
        # Cooperative call: follows the MRO of the concrete class, so when this
        # class is mixed in ahead of another nn.Module subclass, that class's
        # __init__ runs next (before the layer below is created).
        super().__init__()
        self.enc = nn.Linear(in_features=2, out_features=2)

class Decoder(nn.Module):
    """Module that owns a single 2 -> 2 linear layer registered as ``dec``."""

    def __init__(self):
        # Cooperative call: defers to whatever follows Decoder in the MRO of
        # the concrete class (nn.Module when used on its own).
        super().__init__()
        self.dec = nn.Linear(in_features=2, out_features=2)
        
class EncoderDecoder(Encoder, Decoder):
    """Mixin of Encoder and Decoder (the problematic version in the question).

    MRO: EncoderDecoder -> Encoder -> Decoder -> nn.Module.
    """
    def __init__(self):
        # Runs Encoder.__init__, whose cooperative super().__init__() in turn
        # runs Decoder.__init__ and nn.Module.__init__; after this line both
        # `dec` and `enc` are set.
        super(EncoderDecoder, self).__init__()
        # BUG (the subject of this question): super(Encoder, self) resolves to
        # the next class after Encoder in the MRO, i.e. Decoder. This re-runs
        # Decoder.__init__ -> nn.Module.__init__, which re-initializes the
        # module's registries, so `enc` is discarded and only the newly created
        # `dec` survives (see the output below). Removing this line fixes it.
        super(Encoder, self).__init__()
        # NOTE: `self.modules` is not called here, so this prints the bound
        # method object; its repr embeds the module tree shown in the output.
        print(self.modules)

EncoderDecoder()

Output:

EncoderDecoder(
  (dec): Linear(in_features=2, out_features=2, bias=True)
)

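You could remove the `super(Encoder, self).__init__()` call. With `class EncoderDecoder(Encoder, Decoder)` the MRO is EncoderDecoder → Encoder → Decoder → nn.Module, so the first `super().__init__()` already runs the initializers of Encoder and Decoder (and `nn.Module.__init__` once). The second call resolves to `Decoder.__init__`, which runs `nn.Module.__init__` again. That call resets the registered submodules, so `enc` is lost and only the newly created `dec` remains.
In general you can inherit from multiple classes, but I think it is easier to register the modules as attributes of `EncoderDecoder` instead (composition):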

class Encoder(nn.Module):
    """Standalone encoder holding one 2 -> 2 linear layer as ``enc``."""

    def __init__(self):
        # Zero-argument super() is equivalent to super(Encoder, self) here.
        super().__init__()
        self.enc = nn.Linear(in_features=2, out_features=2)

class Decoder(nn.Module):
    """Standalone decoder holding one 2 -> 2 linear layer as ``dec``."""

    def __init__(self):
        # Zero-argument super() is equivalent to super(Decoder, self) here.
        super().__init__()
        self.dec = nn.Linear(in_features=2, out_features=2)
        
class EncoderDecoder(nn.Module):
    """Compose an Encoder and a Decoder as child modules.

    Assigning each instance to an attribute registers it with nn.Module as a
    separate child (``enc`` and ``dec``), avoiding the multiple-inheritance
    initialization pitfalls of the mixin approach.
    """

    def __init__(self):
        super().__init__()
        # Encoder is built before Decoder, as in the original ordering.
        self.enc = Encoder()
        self.dec = Decoder()
2 Likes