Correct way to create wrapper modules around existing modules

Hi everyone, I’m trying to create a wrapper module around an existing module that has parameters, and I’m a bit worried that I may be registering the same parameter several times, so that it gets modified multiple times during optim.step().

For instance, I’d like to create an embedding layer that uses the special embedding dropout featured in the AWD-LSTM paper. This involves creating an embedding layer and then wrapping it in a custom module:


import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module

embed = nn.Embedding(a, b)   # a = vocab size, b = embedding dim

class EmbedDrop(Module):
    def __init__(self, embedlayer, p):
        super().__init__()
        self.embed = embedlayer
        self.weight = embedlayer.weight   # same Parameter, exposed again on the wrapper
        self.dropout = p

    def forward(self, input):
        if not self.training:
            # no dropout at eval time: plain embedding lookup with the shared weights
            return F.embedding(input, self.weight)
        ...
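
To make the worry concrete, this is the kind of check I’ve been running (hypothetical sizes, and I’m not sure I’m interpreting the output correctly):

drop = EmbedDrop(nn.Embedding(10, 3), p=0.1)

print(drop.weight is drop.embed.weight)           # True: two attribute names, one Parameter object
print([name for name, _ in drop.named_parameters()])
# I only see ['weight'] here, which suggests named_parameters() deduplicates
# shared tensors by identity, but that's exactly the part I'd like confirmed.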

Either way, the embedding layer’s weight matrix has been ‘mentioned’ three times across different __init__ calls, twice of them inside the __init__ of the wrapping module. If I’m making a mistake and registering it twice, but for various reasons still need the weight matrix to be an attribute of the wrapping layer, could I get around this by making it a property instead?


class EmbedDrop(Module):
    def __init__(self, embedlayer, p):
        super().__init__()
        self.embed = embedlayer
        self.dropout = p

    def forward(self, input):
        ...

    @property
    def weight(self):
        return self.embed.weight

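With the property version I’d expect the weight to show up only under the nested name, something like this (again just my own sanity check, not an authoritative answer):

drop = EmbedDrop(nn.Embedding(10, 3), p=0.1)

print([name for name, _ in drop.named_parameters()])
# ['embed.weight'] only; the property isn't an nn.Parameter attribute, so as far
# as I can tell nothing extra gets registered on the wrapper itself.
print(drop.weight is drop.embed.weight)           # still True, same underlying tensor
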
Just to emphasize: the embedding-dropout wrapper is only an example for illustration; this is a general question about avoiding registering the same parameter multiple times.

Does anyone know whether exposing the weight as a property also registers it as a parameter? Is it even possible to register a parameter twice, or is this an imaginary problem on my part that can’t actually happen? And is there a way to tell the module at the __init__ stage not to bother registering the parameter (because it’s already been registered)?

Generalizing this and taking it slightly further: I’ve often found myself wanting to make all the attributes of the wrapped layer accessible as attributes of the wrapping layer, so I’ve added this to the __init__ of the wrapper module:

import toolz

class MyWrapper(Module):
    def __init__(self, wrapped_mod):
        super().__init__()
        self.wrapped = wrapped_mod
        # copy every non-"hidden" attribute of the wrapped module onto the wrapper
        non_hidden_attrs = toolz.keyfilter(lambda x: not x.startswith('_'), wrapped_mod.__dict__)
        self.__dict__.update(non_hidden_attrs)

    def forward(self, input):
        ...

Can anyone tell me if there are any PyTorch/autograd-specific reasons I should avoid doing this?
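
In case it helps frame the question, the alternative I’ve been toying with is delegating attribute lookups instead of copying the dict; this is only a sketch, and I’m not sure overriding __getattr__ like this is considered safe either:

class MyWrapper(Module):
    def __init__(self, wrapped_mod):
        super().__init__()
        self.wrapped = wrapped_mod

    def forward(self, input):
        ...

    def __getattr__(self, name):
        # nn.Module.__getattr__ already resolves registered parameters, buffers
        # and submodules; only fall back to the wrapped module for anything else.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.wrapped, name)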

Thanks a lot for any help!
