I need to build a neural network that consists of other, smaller neural networks that are fixed and already trained. For this, I create a hierarchy of nn.Module subclasses, as follows. I have a BaseModel, which is trained independently: I only need to train it once, and once its training is finished I do not need to touch it again. However, I want another model that has some trainable parameters on top of this fixed BaseModel, so I create another nn.Module subclass. Both of them look as follows:
```python
import torch
import torch.nn as nn


class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        # Freeze parameters from the base_model
        for params in self.base_model.parameters():
            params.requires_grad = False
        self.other_effect = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.base_model(x)
        return self.other_effect(x)
```
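To check that the freezing in EffectsModel behaves as intended, here is a minimal, self-contained sketch. The random input data, the dummy loss and the SGD learning rate are arbitrary assumptions for illustration. It lists which parameters are trainable, passes only those to the optimizer, and confirms that a backward pass leaves the frozen base model without gradients:

```python
import torch
import torch.nn as nn


class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        # Freeze parameters from the base_model
        for p in self.base_model.parameters():
            p.requires_grad = False
        self.other_effect = nn.Linear(4, 4)

    def forward(self, x):
        return self.other_effect(self.base_model(x))


model = EffectsModel(BaseModel())

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
print(trainable)  # ['other_effect.weight', 'other_effect.bias']
print(frozen)     # ['base_model.linear.weight', 'base_model.linear.bias']

# Only hand the trainable parameters to the optimizer (lr is arbitrary here).
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)

loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy loss on random data
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Frozen parameters receive no gradient; the new layer does.
print(model.base_model.linear.weight.grad is None)  # True
print(model.other_effect.weight.grad is not None)   # True
```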
Now I want to create as many instances of this EffectsModel as I need for different experiments, so I create my CustomModel, which contains different instances of EffectsModel that all reuse the same BaseModel.
```python
class CustomModel(nn.Module):
    def __init__(self, state_dict):
        super().__init__()
        self.base_model = BaseModel()
        self.base_model.load_state_dict(state_dict)
        self.m1 = EffectsModel(self.base_model)
        self.m2 = EffectsModel(self.base_model)
        ...
        self.mn = EffectsModel(self.base_model)

    def forward(self, x):
        x = self.m1(x)
        x = self.m2(x)
        ...
        x = self.mn(x)
        return x
```
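If the number of EffectsModel instances should be configurable rather than spelled out as m1 … mn, one option is to keep them in an nn.ModuleList. This is a sketch only: the n_effects constructor argument is hypothetical, and a freshly initialized BaseModel's state_dict stands in for the real pretrained weights.

```python
import torch
import torch.nn as nn


class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        for p in self.base_model.parameters():
            p.requires_grad = False  # freeze the shared base model
        self.other_effect = nn.Linear(4, 4)

    def forward(self, x):
        return self.other_effect(self.base_model(x))


class CustomModel(nn.Module):
    def __init__(self, state_dict, n_effects):  # n_effects: hypothetical argument
        super().__init__()
        self.base_model = BaseModel()
        self.base_model.load_state_dict(state_dict)
        # One EffectsModel per experiment, all sharing the same BaseModel object.
        self.effects = nn.ModuleList(
            [EffectsModel(self.base_model) for _ in range(n_effects)]
        )

    def forward(self, x):
        # Apply the effects in order, like m1, m2, ..., mn above.
        for effect in self.effects:
            x = effect(x)
        return x


pretrained_state = BaseModel().state_dict()  # stand-in for the trained weights
model = CustomModel(pretrained_state, n_effects=3)

out = model(torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 4])
print(all(e.base_model is model.base_model for e in model.effects))  # True
```

nn.ModuleList registers each EffectsModel as a submodule, so `.to()`, `.train()`/`.eval()` and `state_dict()` all see them. Note that this does not change the behaviour described below: the state_dict still lists the base_model entries once per effect (`effects.0.base_model...`, `effects.1.base_model...`, and so on).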
When I debug and check `id(self.base_model)` and `id(self.m1.base_model)`…`id(self.mn.base_model)` inside CustomModel, I get the same id, which indicates that I only have one instance of base_model in memory. However, when I call `state_dict()` on CustomModel, I get a separate set of base_model entries for each of self.m1, self.m2, …, self.mn, plus the entries for self.base_model itself.
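The observation above can be reproduced with a small script. For brevity it assumes only m1 and m2 instead of m1 … mn, and uses a freshly initialized BaseModel as the "pretrained" weights. It shows that the module object is shared, that `state_dict()` walks every registered path and therefore emits one set of keys per path, that the repeated entries still point to the same underlying storage, and that `parameters()`, unlike `state_dict()`, skips duplicates:

```python
import torch
import torch.nn as nn


class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        for p in self.base_model.parameters():
            p.requires_grad = False
        self.other_effect = nn.Linear(4, 4)

    def forward(self, x):
        return self.other_effect(self.base_model(x))


class CustomModel(nn.Module):
    def __init__(self, state_dict):
        super().__init__()
        self.base_model = BaseModel()
        self.base_model.load_state_dict(state_dict)
        self.m1 = EffectsModel(self.base_model)
        self.m2 = EffectsModel(self.base_model)

    def forward(self, x):
        return self.m2(self.m1(x))


model = CustomModel(BaseModel().state_dict())

# A single BaseModel object in memory...
print(model.m1.base_model is model.base_model is model.m2.base_model)  # True

sd = model.state_dict()
print(list(sd.keys()))
# ['base_model.linear.weight', 'base_model.linear.bias',
#  'm1.base_model.linear.weight', 'm1.base_model.linear.bias',
#  'm1.other_effect.weight', 'm1.other_effect.bias',
#  'm2.base_model.linear.weight', 'm2.base_model.linear.bias',
#  'm2.other_effect.weight', 'm2.other_effect.bias']

# ...and the repeated entries share storage rather than being copies.
print(sd["base_model.linear.weight"].data_ptr()
      == sd["m1.base_model.linear.weight"].data_ptr())  # True

# parameters() deduplicates shared modules; state_dict() does not.
print(len(sd), len(list(model.parameters())))  # 10 6
```

Because the duplicated entries share storage, `torch.save(sd)` should write the BaseModel weights only once (it serializes each underlying storage once); the duplication is mainly in the keys, and in what `load_state_dict` expects to find.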
Now that I have introduced the problem, the question is simple: is there any way to have a single base_model registered in my state_dict? What is the most efficient way of doing this? Is there any way of unregistering the nn.Module for base_model inside each EffectsModel? And how does nn.Module work internally in this case?
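One possible workaround, sketched under the assumption that EffectsModel is only ever used inside a parent (here CustomModel) that registers base_model itself: store the shared module with `object.__setattr__`, which bypasses `nn.Module.__setattr__`, so the reference lives in the instance's plain `__dict__` instead of its `_modules` dict. Submodules are tracked through `_modules` (that is what `state_dict()`, `parameters()`, `.to()` and `.train()`/`.eval()` recurse over), so base_model then appears only once, under the parent. This is a trick rather than a documented PyTorch feature, and the trade-off is that calling `.to()` or `.eval()` on an EffectsModel on its own no longer reaches base_model.

```python
import torch
import torch.nn as nn


class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        # Workaround (assumption, not an official API): object.__setattr__ skips
        # nn.Module.__setattr__, so base_model is NOT added to self._modules and
        # is therefore left out of this module's state_dict() and parameters().
        object.__setattr__(self, "base_model", base_model)
        for p in base_model.parameters():
            p.requires_grad = False
        self.other_effect = nn.Linear(4, 4)  # registered normally

    def forward(self, x):
        return self.other_effect(self.base_model(x))


class CustomModel(nn.Module):
    def __init__(self, state_dict):
        super().__init__()
        # The shared BaseModel is registered exactly once, here.
        self.base_model = BaseModel()
        self.base_model.load_state_dict(state_dict)
        self.m1 = EffectsModel(self.base_model)
        self.m2 = EffectsModel(self.base_model)

    def forward(self, x):
        return self.m2(self.m1(x))


model = CustomModel(BaseModel().state_dict())
keys = list(model.state_dict().keys())
print(keys)
# ['base_model.linear.weight', 'base_model.linear.bias',
#  'm1.other_effect.weight', 'm1.other_effect.bias',
#  'm2.other_effect.weight', 'm2.other_effect.bias']
print(model.m1.base_model is model.base_model)  # True
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```

A checkpoint saved from the original layout contains extra `m*.base_model.*` keys, so loading it into this version would need those keys filtered out first, or `load_state_dict(..., strict=False)`, which ignores unexpected keys.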
Thank you very much in advance!