Unable to assign nn.Module class to self.encoder in another nn.Module class

In the code shown below, I would like to create an Encoder class and then assign it to self.encoder in the ComponentEmbedding class. However, when I try to access ComponentEmbedding().encoder, the value is still None. Can anyone help?

Note: I need to preserve encoder = None, in case create_encoder = False.

Code:

import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        pass
    
    def forward(self, x):
        pass

class ComponentEmbedding(nn.Module):
    
    encoder = None
    
    def __init__(self, create_encoder=True):
        super().__init__()
        if create_encoder:
            self.init_encoder()
        
    def forward(self, x):
        pass
    
    def init_encoder(self):
        self.encoder = Encoder()

print(ComponentEmbedding().encoder)

Output:

None

Expected Output:

Encoder()

This should work. Set self.encoder = None inside __init__ instead of declaring it at class level:

import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        pass
    
    def forward(self, x):
        pass

class ComponentEmbedding(nn.Module):
    def __init__(self, create_encoder=True):
        super().__init__()
        self.encoder = None
        if create_encoder:
            self.init_encoder()
        
    def forward(self, x):
        pass
    
    def init_encoder(self):
        self.encoder = Encoder()

print(ComponentEmbedding().encoder)

Great, it’s working! Yet, I’m curious why it doesn’t work if we declare encoder = None outside def. I also found out that if I declare Encoder() as a normal class (without nn.Module), it works.

import torch.nn as nn

class Encoder():
    def __init__(self):
        super().__init__()
        pass
    
    def forward(self, x):
        pass

class ComponentEmbedding(nn.Module):
    
    encoder = None
    
    def __init__(self, create_encoder=True):
        super().__init__()
        if create_encoder:
            self.init_encoder()
        
    def forward(self, x):
        pass
    
    def init_encoder(self):
        self.encoder = Encoder()

print(ComponentEmbedding().encoder)

Output:

<__main__.Encoder object at 0x7f9eaddad3d0>

encoder = None

If you define it outside __init__, then it is a class variable, shared by every object of the class. This means that if you create many objects of this class and then change the variable on the class, as in the example below, the change is visible through every object that has not set its own instance attribute of the same name.

class SimpleClass():
    var = 5
    def __init__(self):
        pass

simple1 = SimpleClass()
simple2 = SimpleClass()

print(simple1.var)
print(simple2.var)

SimpleClass.var = 3

print(simple1.var)
print(simple2.var)

# Output:
#5
#5
#3
#3

Here you can read a little bit more on class variables and standard practices.

But if you really want to keep the class variable, you have two options. You can still initialize the attribute for each object inside __init__, which is basically my previous solution with an extra class variable definition. Or you can use it as a true class variable and assign it on the class, which changes it for EVERY object:

def init_encoder(self):
    ComponentEmbedding.encoder = Encoder()
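
To make the difference between the two kinds of assignment concrete, here is a minimal plain-Python sketch. The Config class and its value attribute are made up for illustration and are not part of the original code. Assigning through an instance creates an instance attribute that shadows the class variable for that object only, while assigning through the class changes what every object without its own value sees.

```python
class Config:
    value = None  # class variable, acts as a shared default


a = Config()
b = Config()

a.value = 1       # creates an instance attribute on `a` only
Config.value = 2  # changes the class variable itself

print(a.value)  # 1 (instance attribute shadows the class variable)
print(b.value)  # 2 (falls back to the class variable)
print("value" in vars(a), "value" in vars(b))  # True False
```

Note that this is ordinary Python behaviour. An nn.Module behaves differently when the assigned value is itself a module, which is what the original question ran into.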

Encoder() without nn.Module

This works because you are creating an object of the plain type Encoder. However, it does not have the functionality that an nn.Module has. It is only a simple class with an __init__ and a forward method, so it will not behave like a module.

If you need the functionality of nn.Module then it should inherit from it.

You can check this by calling a standard nn.Module method that you have not implemented yourself, such as eval.

import torch

class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        pass
    
    def forward(self, x):
        pass

enc = Encoder()

enc.eval()  # Without inheriting from torch.nn.Module, this raises an AttributeError

Great explanations, thank you so much :heart:


TL;DR: Do not use class attributes inside nn.Module. At least, not with the same name as instance attributes.

To answer your curiosity, I am just extending @Matias_Vasquez’s answer a bit more.

PyTorch’s nn.Module has its own internal structure under the hood to store module, parameter, and buffer attributes. Whenever an attribute is assigned, the __setattr__ method of nn.Module is called, and the value is stored in the appropriate internal dictionary (_parameters, _modules, _buffers) depending on its type.

In your case, the instance attribute encoder that you create inside the init_encoder() method is stored in nn.Module's _modules dictionary rather than in the instance's __dict__.
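
You can check this yourself with a short sketch, assuming PyTorch is installed. It condenses the classes from the question and peeks at _modules, which is a private attribute used here only for inspection. The Encoder body (returning x unchanged) is a placeholder of mine.

```python
import torch.nn as nn


class Encoder(nn.Module):
    def forward(self, x):
        return x  # placeholder body, for illustration only


class ComponentEmbedding(nn.Module):
    encoder = None  # class attribute, as in the question

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()  # nn.Module routes this into self._modules


m = ComponentEmbedding()
print("encoder" in m._modules)  # True: the submodule was registered
print("encoder" in vars(m))     # False: nothing in the instance __dict__
print(m.encoder)                # None: the class attribute is what you get back
```

So the encoder does exist and is registered as a submodule. It is only the attribute access that returns the class-level None.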

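Now, when you access ComponentEmbedding().encoder, Python follows its normal attribute lookup order. The instance __dict__ has no encoder entry, because nn.Module stored the submodule in _modules instead, so the lookup falls through to the class, finds the class attribute encoder with value None, and returns it. __getattr__ is only called when the normal lookup fails, so nn.Module's own __getattr__, which would have fetched the module from _modules, never gets a chance to run. This is also why the plain Encoder class works: it is not an nn.Module, so the assignment lands in the instance __dict__, which takes precedence over the class attribute.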
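
The same mechanism can be reproduced without PyTorch. The sketch below uses a made-up TinyModule class as a heavily simplified stand-in for nn.Module's attribute handling. It is not PyTorch's actual implementation, only an illustration of how a custom __setattr__ plus a __getattr__ fallback interacts with a class attribute of the same name.

```python
class TinyModule:
    """Simplified stand-in for nn.Module's attribute handling (not PyTorch code)."""

    def __init__(self):
        # Bypass our own __setattr__ to create the registry in __dict__.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, TinyModule):
            self._modules[name] = value            # submodules go to the registry
        else:
            object.__setattr__(self, name, value)  # everything else goes to __dict__

    def __getattr__(self, name):
        # Only called when normal lookup (instance __dict__, then class) fails.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class WithClassAttr(TinyModule):
    child = None  # found by normal lookup, so __getattr__ never runs

    def __init__(self):
        super().__init__()
        self.child = TinyModule()


class WithoutClassAttr(TinyModule):
    def __init__(self):
        super().__init__()
        self.child = TinyModule()


print(WithClassAttr().child)                             # None
print(isinstance(WithoutClassAttr().child, TinyModule))  # True
```

In both subclasses the child ends up in the registry. Only the class attribute decides whether the fallback is ever reached.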

Additional reference on metaclasses, attribute lookup order.


Great explanations! That satisfied my curiosity, thank you :heart: