In the code shown below, I would like to create an Encoder object and then assign it to self.encoder in the ComponentEmbedding class. However, when I try to access ComponentEmbedding().encoder, the value is still None. Can anyone help?
Note: I need to keep encoder = None for the case where create_encoder = False.
Code:
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        pass

    def forward(self, x):
        pass

class ComponentEmbedding(nn.Module):
    encoder = None

    def __init__(self, create_encoder=True):
        super().__init__()
        if create_encoder:
            self.init_encoder()

    def forward(self, x):
        pass

    def init_encoder(self):
        self.encoder = Encoder()

print(ComponentEmbedding().encoder)
Output:
None
Expected Output:
Encoder()
This should work:
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        pass

    def forward(self, x):
        pass

class ComponentEmbedding(nn.Module):
    def __init__(self, create_encoder=True):
        super().__init__()
        self.encoder = None  # instance attribute, set per object
        if create_encoder:
            self.init_encoder()

    def forward(self, x):
        pass

    def init_encoder(self):
        self.encoder = Encoder()

print(ComponentEmbedding().encoder)
Great, it’s working! Yet I’m curious why it doesn’t work if we declare encoder = None outside def. I also found out that if I declare Encoder as a normal class (without nn.Module), it works.
import torch.nn as nn

class Encoder():
    def __init__(self):
        super().__init__()
        pass

    def forward(self, x):
        pass

class ComponentEmbedding(nn.Module):
    encoder = None

    def __init__(self, create_encoder=True):
        super().__init__()
        if create_encoder:
            self.init_encoder()

    def forward(self, x):
        pass

    def init_encoder(self):
        self.encoder = Encoder()

print(ComponentEmbedding().encoder)
Output:
<__main__.Encoder object at 0x7f9eaddad3d0>
encoder = None
If you define it outside __init__, then it is a class variable. This means the value is shared by every object of the class, so by doing something like this ↓ you will modify the variable for every object created (as long as the object has not been given its own attribute with the same name).
class SimpleClass():
    var = 5

    def __init__(self):
        pass

simple1 = SimpleClass()
simple2 = SimpleClass()
print(simple1.var)
print(simple2.var)
SimpleClass.var = 3
print(simple1.var)
print(simple2.var)
# Output:
# 5
# 5
# 3
# 3
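To make the difference between the class variable and a per-object attribute concrete, here is a minimal sketch in plain Python (no PyTorch involved). It reuses the SimpleClass name from above; the values 10 and 3 are arbitrary. Assigning through an instance creates an attribute on that one object, which then shadows the class variable for that object only.

```python
class SimpleClass:
    var = 5  # class variable, shared by all instances


simple1 = SimpleClass()
simple2 = SimpleClass()

simple1.var = 10        # creates an instance attribute on simple1 only
print(simple1.var)      # 10
print(simple2.var)      # 5
print(SimpleClass.var)  # 5

SimpleClass.var = 3     # rebinds the class variable
print(simple1.var)      # 10 (the instance attribute still shadows it)
print(simple2.var)      # 3

# Only simple1 carries its own "var" entry.
print("var" in vars(simple1), "var" in vars(simple2))  # True False
```

This is the behaviour a plain Python class gets by default; nn.Module changes where some assigned values are stored, which is why the original question behaves differently.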
Here you can read a little bit more on class variables and standard practices.
But if you really want to use it like this, you would still have to either initialize it for this object inside __init__, which is basically the same as my previous solution with an extra class variable definition, OR use it as a class variable and modify this variable for EVERY object:
    def init_encoder(self):
        ComponentEmbedding.encoder = Encoder()
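As a rough illustration of what "for EVERY object" means in practice, the sketch below uses plain Python stand-ins (this Encoder is a hypothetical empty class, not an nn.Module) so that it runs without PyTorch. It shows that all objects end up sharing one encoder and that create_encoder=False no longer yields None once any other object has created one.

```python
class Encoder:
    """Plain stand-in for the real encoder (hypothetical, no PyTorch needed)."""


class ComponentEmbedding:
    encoder = None  # class variable

    def __init__(self, create_encoder=True):
        if create_encoder:
            self.init_encoder()

    def init_encoder(self):
        # Rebinds the class variable, so every instance sees the new object.
        ComponentEmbedding.encoder = Encoder()


first = ComponentEmbedding()
second = ComponentEmbedding()                       # replaces the shared encoder
without = ComponentEmbedding(create_encoder=False)  # still sees the shared encoder

print(first.encoder is second.encoder)  # True
print(without.encoder is None)          # False
```

With a real nn.Module there is a further catch: assigning on the class does not go through the module's own attribute handling, so the encoder would not be registered as a submodule of any instance. For the requirement in the question, setting self.encoder = None inside __init__ looks like the safer choice.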
Encoder() without nn.Module
This works because you are creating a plain object of type Encoder. Since it is not an nn.Module, the assignment stores it as an ordinary instance attribute, which takes precedence over the class variable. However, it does not have all the functionality that an nn.Module has. It is only a simple class with an __init__ and a forward method, but this does not mean that it will function the same.
If you need the functionality of nn.Module, then it should inherit from it.
You can try this by attempting to call a standard nn.Module method that you have not implemented yourself, such as eval.
import torch

class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        pass

    def forward(self, x):
        pass

enc = Encoder()
enc.eval()  # Without torch.nn.Module this will result in an error
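For comparison, here is a small sketch of the failing case, using a hypothetical PlainEncoder class that does not inherit from nn.Module. It needs no PyTorch at all, since the point is only that a plain class has no eval method unless you write one.

```python
class PlainEncoder:
    """A plain class with a forward method but no nn.Module base."""

    def forward(self, x):
        return x


enc = PlainEncoder()
print(hasattr(enc, "eval"))  # False

try:
    enc.eval()
except AttributeError as err:
    message = str(err)
    print(message)
```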
Great explanations, thank you so much
TL;DR: Do not use class attributes inside nn.Module. At least, not with the same name as instance attributes.
To answer your curiosity, I am just extending @Matias_Vasquez’s answer a bit more.
PyTorch’s nn.Module has its own internal structure under the hood to store module/parameter/buffer attributes. Whenever an attribute is assigned, the __setattr__ method of nn.Module is called and the value is stored in the appropriate internal dictionary (_parameters, _modules, _buffers) depending on its type. Values of any other type are stored as ordinary instance attributes.
In your case, the instance attribute encoder that you create inside the init_encoder() method will be stored in nn.Module’s _modules dictionary.
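You can check this yourself with a short sketch like the one below. It assumes PyTorch is installed and peeks at _modules, which is a private attribute of nn.Module, so treat it as a debugging aid rather than something to rely on in real code. The Encoder here is a trivial placeholder.

```python
import torch.nn as nn


class Encoder(nn.Module):
    def forward(self, x):
        return x


class ComponentEmbedding(nn.Module):
    encoder = None  # class attribute, as in the original question

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()  # routed through nn.Module.__setattr__


m = ComponentEmbedding()
print("encoder" in m._modules)  # True: the submodule was registered
print("encoder" in m.__dict__)  # False: it is not a plain instance attribute
print(m.encoder)                # None: the class attribute is what you get back
```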
Now, when you access ComponentEmbedding().encoder, it is searched for according to Python’s attribute lookup order. The instance __dict__ does not contain encoder (it lives in _modules), so the lookup continues to the class __dict__, finds the encoder attribute with value None, and returns it. nn.Module’s own __getattr__, which would look in _modules, is only called when the normal lookup fails, so the class attribute encoder never gives it a chance to run.
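The lookup order is easier to see in a toy model. The sketch below is plain Python and only imitates the relevant part of nn.Module: MiniModule, Shadowed, ShadowedPlain and Fixed are made-up names, and the real PyTorch implementation handles many more cases. It should still reproduce all three observations from this thread.

```python
class MiniModule:
    """Toy stand-in for nn.Module's attribute handling (not PyTorch's real code)."""

    def __init__(self):
        # Bypass our own __setattr__ so the registry lands in the instance __dict__.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, MiniModule):
            # Submodules go into the registry, not into the instance __dict__.
            self.__dict__.pop(name, None)
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Called only when normal lookup (instance __dict__, then class) fails.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class Encoder(MiniModule):
    pass


class PlainEncoder:  # not a MiniModule
    pass


class Shadowed(MiniModule):
    encoder = None  # class attribute, found before __getattr__ is tried

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()  # stored in _modules


class ShadowedPlain(MiniModule):
    encoder = None

    def __init__(self):
        super().__init__()
        self.encoder = PlainEncoder()  # stored in the instance __dict__


class Fixed(MiniModule):
    def __init__(self, create_encoder=True):
        super().__init__()
        self.encoder = None  # instance attribute, no class attribute involved
        if create_encoder:
            self.encoder = Encoder()


shadowed = Shadowed()
plain = ShadowedPlain()
fixed = Fixed()
print(shadowed.encoder)                     # None
print(type(plain.encoder).__name__)         # PlainEncoder
print(type(fixed.encoder).__name__)         # Encoder
print(Fixed(create_encoder=False).encoder)  # None
```

In Shadowed the class attribute wins because nothing named encoder sits in the instance __dict__. In ShadowedPlain the plain object is written to the instance __dict__ and shadows the class attribute. In Fixed there is no class attribute at all, so the lookup falls through to __getattr__ and the registry.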
Additional reference on metaclasses, attribute lookup order.
Great explanations! solved my curiosity, thank you