I have a module that uses another module as its basis, which I use like this:

    class FirstModule(nn.Module):
        def __init__(self, secondModule):
            super().__init__()
            self.secondModule = secondModule
            # or: self.add_module('secondModule', secondModule)
            # other things...

The problem with this is that the parameters of secondModule show up in firstModule.parameters(), which I don’t want: I need an instance of the second module there, but I don’t need its parameters and won’t backpropagate through them.
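A minimal sketch reproducing the behaviour described above (the `head` layer is a hypothetical stand-in for "other things..."): assigning an `nn.Module` to an attribute registers it as a submodule, so its parameters are listed alongside the module's own.

```python
import torch.nn as nn


class FirstModule(nn.Module):
    def __init__(self, secondModule: nn.Module):
        super().__init__()
        # Assigning an nn.Module to an attribute registers it as a submodule,
        # so its parameters become part of this module's parameters().
        self.secondModule = secondModule
        # Hypothetical layer standing in for "other things...".
        self.head = nn.Linear(4, 2)


first = FirstModule(nn.Linear(3, 4))
names = [name for name, _ in first.named_parameters()]
print(names)
# ['secondModule.weight', 'secondModule.bias', 'head.weight', 'head.bias']
```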
So I resorted to wrapping the second module instance in a list, so that its parameters are invisible:

    class FirstModule(nn.Module):
        def __init__(self, secondModule):
            super().__init__()
            self.secondModule = [secondModule]
            # other things...

The issue with this (apart from being awkward) is that sometimes I do want PyTorch to know that secondModule is there. For example, when calling firstModule.cuda(), I would like secondModule.cuda() to be called too, which doesn’t happen in this case.
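A small sketch of the list-wrapping trade-off, again with a hypothetical `head` layer. It uses `.double()` instead of `.cuda()` so it runs without a GPU; both are module-level conversions that only reach registered submodules, so the wrapped module is skipped in either case.

```python
import torch
import torch.nn as nn


class FirstModule(nn.Module):
    def __init__(self, secondModule: nn.Module):
        super().__init__()
        # A plain Python list is not inspected by nn.Module, so the wrapped
        # module is neither registered nor visible to parameters().
        self.secondModule = [secondModule]
        self.head = nn.Linear(4, 2)  # hypothetical "other things..."


first = FirstModule(nn.Linear(3, 4))
names = [name for name, _ in first.named_parameters()]
print(names)  # ['head.weight', 'head.bias']

# .double(), like .cuda(), only converts registered submodules,
# so the hidden module keeps its original dtype.
first.double()
print(first.head.weight.dtype)              # torch.float64
print(first.secondModule[0].weight.dtype)   # torch.float32
```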
So what is the cleanest way of handling this situation? Is there a way to remove the parameters of secondModule from the firstModule parameter list, but in such a way that other methods (such as .cuda()) are still aware that secondModule is there?
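One possible workaround, sketched under assumptions: keep the list wrapper, but override `nn.Module._apply`, the method that `.cuda()`, `.cpu()`, `.to()`, `.double()` and similar conversions go through, so the same conversion is forwarded to the hidden module. Note that `_apply` is a private, undocumented method whose signature has changed between PyTorch versions (hence `*args, **kwargs`), so this may break in future releases. `.train()`/`.eval()` do not use `_apply`, so `train` is forwarded separately; `state_dict()` will still not include the hidden module. The `_hidden` attribute and `head` layer are hypothetical names.

```python
import torch
import torch.nn as nn


class FirstModule(nn.Module):
    def __init__(self, secondModule: nn.Module):
        super().__init__()
        # Hidden from parameters(), named_parameters(), state_dict(), ...
        self._hidden = [secondModule]
        self.head = nn.Linear(4, 2)  # hypothetical "other things..."

    @property
    def secondModule(self) -> nn.Module:
        return self._hidden[0]

    def _apply(self, fn, *args, **kwargs):
        # Private hook used by .cuda(), .cpu(), .to(), .double(), ...
        # Convert registered submodules first, then the hidden module.
        result = super()._apply(fn, *args, **kwargs)
        self.secondModule._apply(fn, *args, **kwargs)
        return result

    def train(self, mode: bool = True):
        # .train()/.eval() bypass _apply, so forward the mode explicitly.
        super().train(mode)
        self.secondModule.train(mode)
        return self


first = FirstModule(nn.Linear(3, 4))
print([name for name, _ in first.named_parameters()])  # ['head.weight', 'head.bias']
first.double()  # with a GPU, first.cuda() is forwarded the same way
first.eval()
print(first.secondModule.weight.dtype, first.secondModule.training)
```

If the goal is only to keep the optimizer away from those parameters, a simpler alternative is to leave secondModule registered, call `secondModule.requires_grad_(False)`, and pass only parameters with `requires_grad=True` to the optimizer; the parameters will then still appear in `parameters()` and `state_dict()`, though.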