How to exclude parameters from model?


I have a module which uses another module as a basis, and I use it like this:

import torch.nn as nn

class FirstModule(nn.Module):
    def __init__(self, secondModule):
        super().__init__()  # required before assigning submodules
        self.secondModule = secondModule  # or self.add_module('secondModule', secondModule)
        # other things...

The problem with this is that the parameters of secondModule show up in the firstModule parameter list, which I don’t want. I need an instance of the second module there, but I don’t need its parameters and won’t backpropagate through them.
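To make the problem concrete, here is a minimal sketch. SecondModule and the head layer are hypothetical stand-ins for the real modules; the point is only that a registered submodule's parameters are listed by firstModule.named_parameters().

```python
import torch.nn as nn

class SecondModule(nn.Module):  # hypothetical stand-in for the shared module
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

class FirstModule(nn.Module):
    def __init__(self, secondModule):
        super().__init__()
        self.secondModule = secondModule  # registered as a submodule
        self.head = nn.Linear(4, 2)  # hypothetical layer owned by FirstModule

first = FirstModule(SecondModule())
names = [name for name, _ in first.named_parameters()]
print(names)
# ['secondModule.linear.weight', 'secondModule.linear.bias', 'head.weight', 'head.bias']
```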

So I resorted to wrapping the second module instance in a list, so that its parameters are invisible:

class FirstModule(nn.Module):
    def __init__(self, secondModule):
        super().__init__()
        self.secondModule = [secondModule]  # a plain list is not registered by nn.Module
        # other things...

The issue with this (apart from being awkward) is that sometimes I would like PyTorch to know that secondModule is there. For example, when calling firstModule.cuda(), I would like secondModule.cuda() to be called too, which won’t happen in this case.
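The side effect can be reproduced without a GPU. This sketch uses .double() instead of .cuda(), assuming (as is the case for nn.Module) that both conversions walk the same registered-submodule tree; the head layer is hypothetical. The list-wrapped module is hidden from parameters(), but it is also skipped by the conversion.

```python
import torch
import torch.nn as nn

class FirstModule(nn.Module):
    def __init__(self, secondModule):
        super().__init__()
        self.secondModule = [secondModule]  # plain list: not registered
        self.head = nn.Linear(4, 2)  # hypothetical layer owned by FirstModule

second = nn.Linear(4, 4)
first = FirstModule(second)
print(len(list(first.parameters())))  # 2: only head.weight and head.bias

first.double()  # converts registered submodules, like .cuda() moves them
print(first.head.weight.dtype)  # torch.float64
print(second.weight.dtype)      # torch.float32: the wrapped module was skipped
```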

So what is the cleanest way of handling this? Is there a way to remove the parameters of secondModule from the firstModule parameter list, but in such a way that other functions are aware that secondModule is there?


If you do not want to backprop through the parameters of self.secondModule, you could do:

for p in self.secondModule.parameters():
    p.requires_grad = False
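A minimal sketch of what freezing does and does not change, with a hypothetical head layer: the frozen parameters still appear in first.parameters(), but requires_grad now gives you something to filter on when building the optimizer (SGD and lr=0.1 are arbitrary choices here).

```python
import torch
import torch.nn as nn

class FirstModule(nn.Module):
    def __init__(self, secondModule):
        super().__init__()
        self.secondModule = secondModule
        self.head = nn.Linear(4, 2)  # hypothetical layer owned by FirstModule

first = FirstModule(nn.Linear(4, 4))
for p in first.secondModule.parameters():
    p.requires_grad = False  # frozen, but still registered

all_params = list(first.parameters())
trainable = [p for p in first.parameters() if p.requires_grad]
print(len(all_params), len(trainable))  # 4 2

optimizer = torch.optim.SGD(trainable, lr=0.1)
```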

Thank you for your answer! But not only do I not want to backprop, I also don’t want those parameters to show up in self.parameters(), as I need to do something to the parameters of firstModule that I don’t want to do to those of secondModule.

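Filter the parameters, e.g. with a lambda, and pass only the ones you want to the optimizer:

import torch

# ids of the parameters to leave out (here: those of `decoder`)
excluded = {id(p) for p in decoder.parameters()}

# keep every parameter of `model` that is not in the excluded set
new_params = list(filter(lambda p: id(p) not in excluded, model.parameters()))
optimizer = torch.optim.Adam(new_params, lr=args.learning_rate)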
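Applied to the original question, a minimal sketch could keep secondModule registered (so .cuda(), .to() and .double() still reach it) and filter its parameters out only where needed. own_parameters is a hypothetical helper, not a PyTorch API, and the head layer and learning rate are arbitrary.

```python
import torch
import torch.nn as nn

class FirstModule(nn.Module):
    def __init__(self, secondModule):
        super().__init__()
        self.secondModule = secondModule  # registered, so module-wide conversions reach it
        self.head = nn.Linear(4, 2)  # hypothetical layer owned by FirstModule

    def own_parameters(self):
        # hypothetical helper: parameters that do not belong to secondModule
        excluded = {id(p) for p in self.secondModule.parameters()}
        return [p for p in self.parameters() if id(p) not in excluded]

first = FirstModule(nn.Linear(4, 4))
first.double()  # convert/move the model before building the optimizer

own = first.own_parameters()
optimizer = torch.optim.Adam(own, lr=1e-3)
print(len(own))                         # 2
print(first.secondModule.weight.dtype)  # torch.float64
```

Filtering by id() rather than with `p in some_list` avoids comparing tensors element-wise.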