I need to build a neural network composed of other smaller, fixed, already-trained neural networks. To do this, I create a hierarchy of nn.Module subclasses, as follows. I have a BaseModel that is trained independently: I only need to train it once, and once its training is finished I never need to touch it again. However, I want another model that adds some trainable parameters on top of this fixed BaseModel, so I create another nn.Module subclass. Both of them look as follows:
```python
import torch
import torch.nn as nn

class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super(EffectsModel, self).__init__()
        self.base_model = base_model
        # Freeze parameters from the base_model
        for params in self.base_model.parameters():
            params.requires_grad = False
        self.other_effect = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.base_model(x)
        return self.other_effect(x)
```
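A quick way to check that the freezing loop in EffectsModel does what is intended is to run one optimizer step and confirm that only other_effect changes. This is a minimal sketch: the SGD optimizer, the learning rate, the random input and the squared-output loss are arbitrary choices made only for illustration. Note that setting requires_grad = False only stops gradient updates; it does not put layers such as Dropout or BatchNorm into eval mode (this BaseModel has neither).

```python
import torch
import torch.nn as nn


class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        # Freeze every parameter of the shared base model
        for p in self.base_model.parameters():
            p.requires_grad = False
        self.other_effect = nn.Linear(4, 4)

    def forward(self, x):
        return self.other_effect(self.base_model(x))


torch.manual_seed(0)
base = BaseModel()
effects = EffectsModel(base)

# Only parameters that still require gradients are handed to the optimizer
trainable_names = [n for n, p in effects.named_parameters() if p.requires_grad]
optimizer = torch.optim.SGD([p for p in effects.parameters() if p.requires_grad], lr=0.1)

base_before = base.linear.weight.detach().clone()
head_before = effects.other_effect.weight.detach().clone()

# One arbitrary training step on random data (illustrative loss only)
loss = effects(torch.randn(8, 4)).pow(2).sum()
loss.backward()
optimizer.step()

base_unchanged = torch.equal(base_before, base.linear.weight)
head_changed = not torch.equal(head_before, effects.other_effect.weight)
print(trainable_names)
print(base_unchanged, head_changed)
```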
Now I want to create as many instances of this EffectsModel as I want, for different experiments, so I create my CustomModel containing several EffectsModel instances that all reuse the same BaseModel.
```python
class CustomModel(nn.Module):
    def __init__(self, state_dict):
        super(CustomModel, self).__init__()
        self.base_model = BaseModel()
        self.base_model.load_state_dict(state_dict)
        self.m1 = EffectsModel(self.base_model)
        self.m2 = EffectsModel(self.base_model)
        ...
        self.mn = EffectsModel(self.base_model)

    def forward(self, x):
        x = self.m1(x)
        x = self.m2(x)
        ...
        x = self.mn(x)
        return x
```
When I debug inside CustomModel and compare id(self.base_model) with id(self.m1.base_model)…id(self.mn.base_model), I get the same id, which indicates that I only have one instance of base_model in memory. However, when I call the state_dict function of CustomModel, I get a copy of the base_model entries under each of self.m1, self.m2, …, self.mn, in addition to those of self.base_model.
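The behaviour described above can be reproduced in a few lines. This is a minimal sketch, assuming two effect models m1 and m2 stand in for m1 … mn and that CustomModel builds a fresh BaseModel instead of loading a state_dict. state_dict() walks every registered submodule path without removing duplicates, so the shared base_model appears under three prefixes, yet the tensors behind those entries share the same memory; parameters(), by contrast, yields each shared Parameter only once.

```python
import torch
import torch.nn as nn


class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model  # registered as a submodule of this EffectsModel
        for p in self.base_model.parameters():
            p.requires_grad = False
        self.other_effect = nn.Linear(4, 4)

    def forward(self, x):
        return self.other_effect(self.base_model(x))


class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = BaseModel()
        # Both effect models hold a reference to the same BaseModel object
        self.m1 = EffectsModel(self.base_model)
        self.m2 = EffectsModel(self.base_model)

    def forward(self, x):
        return self.m2(self.m1(x))


model = CustomModel()
sd = model.state_dict()
for key in sd:
    print(key)

# Same Python object reachable under three attribute paths
same_object = model.base_model is model.m1.base_model is model.m2.base_model
# The duplicated state_dict entries point at the same underlying memory
shared_storage = (
    sd["base_model.linear.weight"].data_ptr()
    == sd["m1.base_model.linear.weight"].data_ptr()
    == sd["m2.base_model.linear.weight"].data_ptr()
)
# parameters() deduplicates shared Parameters
num_params = len(list(model.parameters()))
print(len(sd), same_object, shared_storage, num_params)
```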
Now that I have introduced the problem, the question is simple: is there any way to have a single base_model registered in my state_dict? What is the most efficient way of doing this? Is there any way of unregistering the nn.Module corresponding to base_model inside each EffectsModel? And how does nn.Module work internally in this case?
Thank you very much in advance!