Unregister / prevent from registering a nn.Module

I need to build a neural network that consists of other smaller, fixed, already-trained neural networks. To do this, I create a hierarchy of nn.Module children, as follows. I have a BaseModel that is trained independently: I only need to train it once, and after its training is finished I never touch it again. On top of this fixed BaseModel I want another model with some trainable parameters of its own, so I create another nn.Module for it. Both look like this:

import torch
import torch.nn as nn

class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        # Freeze parameters from the base_model
        for params in self.base_model.parameters():
            params.requires_grad = False
        self.other_effect = nn.Linear(4, 4)

    def forward(self, x):
        x = self.base_model(x)
        return self.other_effect(x)

Now I want to create as many instances of this EffectsModel as I need for different experiments, so I create a CustomModel containing several EffectsModel instances that all reuse the same BaseModel:

class CustomModel(nn.Module):
    def __init__(self, state_dict):
        super().__init__()
        self.base_model = BaseModel()
        # Assumed: state_dict holds the pre-trained BaseModel weights
        self.base_model.load_state_dict(state_dict)

        self.m1 = EffectsModel(self.base_model)
        self.m2 = EffectsModel(self.base_model)
        self.mn = EffectsModel(self.base_model)

    def forward(self, x):
        x = self.m1(x)
        x = self.m2(x)
        x = self.mn(x)
        return x

When I debug inside CustomModel and check id(self.base_model) and id(self.m1.base_model)…id(self.mn.base_model), I get the same id, which indicates that there is only one instance of base_model in memory. However, when I call state_dict() on CustomModel, I get a separate set of base_model entries under each of self.m1, self.m2 and self.mn, in addition to the ones under self.base_model.
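A minimal sketch to check this behaviour, using cut-down versions of the classes above (two EffectsModel instances instead of n, no pre-trained weights loaded). state_dict() walks every registered child by name and does not deduplicate, so the shared BaseModel shows up once per path; all those keys still point at the same tensor memory, and parameters() does deduplicate. As far as I know, torch.save also writes each shared storage only once, so the aliases should not multiply the file size.

```python
import torch
import torch.nn as nn

class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model  # registered as a child module here
        for p in self.base_model.parameters():
            p.requires_grad = False
        self.other_effect = nn.Linear(4, 4)

    def forward(self, x):
        return self.other_effect(self.base_model(x))

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = BaseModel()
        self.m1 = EffectsModel(self.base_model)
        self.m2 = EffectsModel(self.base_model)

    def forward(self, x):
        return self.m2(self.m1(x))

model = CustomModel()
sd = model.state_dict()

# One key per path to the shared module, not one per object
print([k for k in sd if k.endswith("linear.weight")])
# ['base_model.linear.weight', 'm1.base_model.linear.weight', 'm2.base_model.linear.weight']

# ...but every alias points at the same memory
print(sd["base_model.linear.weight"].data_ptr() == sd["m2.base_model.linear.weight"].data_ptr())
# True
```

So the duplication is only in the key names: 10 state_dict entries, but only 6 distinct parameters.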

Now that I have introduced the problem, the question is simple: is there any way to have a single base_model registered in my state_dict? What is the most efficient way of doing this? Can I unregister the nn.Module corresponding to base_model inside each EffectsModel? And how does nn.Module work internally in this case?

Thank you very much in advance!

You’re creating multiple instances of the inner model; each instance has its own weights and biases. So, when saving, you’re saving each one of them.

Thank you for your answer, my3bikaht.

When I check id(self.base_model) and id(self.m1.base_model)…id(self.mn.base_model), I get the same value. As I understand it, when they have the same id, they are just references to the same object.

In any case, do you know of any way to avoid multiple instances and keep just a single reference to the same object?
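One common workaround, sketched here as an assumption rather than an official API: hide the shared module from nn.Module's registration by storing it inside a plain tuple. nn.Module.__setattr__ only registers values that are Parameters, Modules or buffers, so a tuple is kept as an ordinary attribute. CustomModel stays the single registered owner of the base. The attribute name `_base` is hypothetical.

```python
import torch
import torch.nn as nn

class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

class EffectsModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        # A tuple is neither a Parameter nor a Module, so nn.Module.__setattr__
        # stores it as a plain attribute: base_model is NOT registered here.
        self._base = (base_model,)  # hypothetical name
        self.other_effect = nn.Linear(4, 4)

    def forward(self, x):
        x = self._base[0](x)  # still calls the shared module
        return self.other_effect(x)

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = BaseModel()  # the only place the base is registered
        self.base_model.requires_grad_(False)  # freeze it once
        self.m1 = EffectsModel(self.base_model)
        self.m2 = EffectsModel(self.base_model)

    def forward(self, x):
        return self.m2(self.m1(x))

model = CustomModel()
print(list(model.state_dict()))
# ['base_model.linear.weight', 'base_model.linear.bias',
#  'm1.other_effect.weight', 'm1.other_effect.bias',
#  'm2.other_effect.weight', 'm2.other_effect.bias']
```

The trade-off is that the hidden reference is invisible to every nn.Module utility: calling .to(), .eval() or .state_dict() on an EffectsModel by itself will not reach the base. In this layout that is acceptable because CustomModel still registers the base, so model.to(device) moves it.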

You can just create one instance of the inner model and pass all inputs through it. In your example, though, there’s no actual reason to have one: you can pass different inputs through the base model directly. But it really depends on what you want to achieve. If you want to train multiple (sub)models separately, each on its own inputs, then your architecture is just fine. Otherwise it can hurt the training process: for example, with triplet loss, if you train one instance on positive examples only and another on negative ones, both models may fail to converge.

So, instead of m1…mn, you can create one instance m and pass all examples through it. I’m not sure about the ids; I just tested it here and got different ids per instance.
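To illustrate the "one instance m" suggestion: a minimal sketch, assuming the different groups of examples share the same feature size. It concatenates them along the batch dimension, runs the single module once and splits the result back. nn.Linear is only a stand-in for the real sub-model.

```python
import torch
import torch.nn as nn

# Stand-in for a single shared sub-model "m"
m = nn.Linear(4, 4)

# Separate groups of examples (what m1, m2, m3 would otherwise each receive)
inputs = [torch.randn(2, 4), torch.randn(5, 4), torch.randn(3, 4)]

# One forward pass over the concatenated batch, then split back per group
sizes = [t.shape[0] for t in inputs]
out = m(torch.cat(inputs, dim=0))
outputs = torch.split(out, sizes, dim=0)
```

Every group now goes through the same weights, which is exactly why this only fits when the groups are meant to share parameters.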

Thank you for your reply.

The example is just a simplification of what I really need. The forward pass of the EffectsModel would be more like:

    def forward(self,x):
        x = self.some_effects(x)
        x = self.base_model(x)
        return self.other_effect(x)

There, the effects of each of m1 to mn would be different. I also need the learnable parameters of each of m1 to mn to be different.
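Another option, sketched under assumptions: make CustomModel the only owner of the frozen BaseModel and pass it into each EffectsModel's forward, so no EffectsModel stores it at all. Here `some_effects` is a hypothetical nn.Linear standing in for the real per-instance effects, `n_effects` is an illustrative parameter, and loading the pre-trained base weights is left out. Each instance keeps its own some_effects and other_effect parameters, and only the trainable ones go to the optimizer.

```python
import torch
import torch.nn as nn

class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

class EffectsModel(nn.Module):
    # Holds only its own trainable parts; the shared base is passed in at call time.
    def __init__(self):
        super().__init__()
        self.some_effects = nn.Linear(4, 4)  # hypothetical per-instance pre-effect
        self.other_effect = nn.Linear(4, 4)

    def forward(self, x, base_model):
        x = self.some_effects(x)
        x = base_model(x)
        return self.other_effect(x)

class CustomModel(nn.Module):
    def __init__(self, n_effects=3):
        super().__init__()
        self.base_model = BaseModel()
        self.base_model.requires_grad_(False)  # freeze the pre-trained base once
        self.effects = nn.ModuleList([EffectsModel() for _ in range(n_effects)])

    def forward(self, x):
        for effect in self.effects:
            x = effect(x, self.base_model)
        return x

model = CustomModel()
# Only the per-instance effect parameters are handed to the optimizer
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
```

The state_dict then contains base_model.* exactly once plus effects.0.*, effects.1.*, and so on. If the real BaseModel contains dropout or batch norm, note that requires_grad=False does not switch those layers to inference behaviour; calling model.base_model.eval() after model.train() takes care of that.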

I am still looking for an optimal way to do this.