Forward hook containing trainable parameters

I would like to use register_forward_pre_hook to modify the input passed to the forward call of a module that I cannot modify. The hook itself calls another module that has trainable parameters.

Let’s assume we have a module FixModule that we cannot modify. We also can’t change the input x before it is passed to forward, because that call happens somewhere deep inside a predefined PyTorch module.

import torch
import torch.nn as nn

class FixModule(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.l1 = nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)

I define a forward pre-hook like this:

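class Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 1)

    def forward(self, x):
        return torch.sigmoid(self.l1(x))

linear = Linear()

def hook(self, input):
    # Feed the original input through the external layer first
    return (linear(input[0]), )

model = FixModule()
model.register_forward_pre_hook(hook)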

The hook takes the input x and applies a linear transformation followed by a sigmoid activation. However, the parameters of linear are not part of model, so model.parameters() does not return them. I could pass the parameters of both model and linear to the optimizer manually, but I suspect there is a nicer way to do this.
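For comparison, here is a minimal sketch of the manual workaround mentioned above. It assumes the class definitions from the question with `super().__init__()` added; the use of `itertools.chain`, SGD and `lr=0.1` are illustrative choices, not part of the original post. It shows that `model.parameters()` only lists FixModule's own two tensors, so the hook's layer has to be passed to the optimizer separately.

```python
import itertools

import torch
import torch.nn as nn


class FixModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)


class Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 1)

    def forward(self, x):
        return torch.sigmoid(self.l1(x))


linear = Linear()
model = FixModule()


def hook(module, inputs):
    # Replace the positional input with the output of the external layer.
    return (linear(inputs[0]),)


model.register_forward_pre_hook(hook)

# The external layer is not a submodule, so model.parameters() misses it:
# only FixModule's l1.weight and l1.bias are listed.
n_model_params = len(list(model.parameters()))

# Manual workaround: hand both parameter sets to the optimizer explicitly.
optimizer = torch.optim.SGD(
    itertools.chain(model.parameters(), linear.parameters()), lr=0.1
)

# One training step: gradients still flow into `linear` through the hook.
x = torch.randn(4, 1)
loss = model(x).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

This works, but every piece of code that builds an optimizer, saves a checkpoint or moves the model to a device has to remember `linear` as well.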

Is there a clean way to achieve this?


The way I would do this is to attach the new layer directly to the Module as a submodule, so that its parameters are registered together with the module's own:

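def get_hook_for(mod):
    # Assigning an nn.Module attribute registers it as a submodule of mod,
    # so its parameters are included in mod.parameters()
    mod.my_linear = Linear()
    def hook_fn(self, input):
        # mod and self are the same here, so you can use either of them
        return (self.my_linear(input[0]), )
    return hook_fn

model = FixModule()
model.register_forward_pre_hook(get_hook_for(model))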
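A self-contained sketch of this approach end to end, assuming the class definitions from the question with `super().__init__()` added. The optimizer, learning rate, random input and squared-output loss are illustrative only. It checks that `my_linear`'s parameters appear in `model.named_parameters()` and that an optimizer built only from `model.parameters()` updates them.

```python
import torch
import torch.nn as nn


class FixModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)


class Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 1)

    def forward(self, x):
        return torch.sigmoid(self.l1(x))


def get_hook_for(mod):
    # Assigning an nn.Module attribute registers it as a submodule of mod,
    # so its parameters show up in mod.parameters() and mod.state_dict().
    mod.my_linear = Linear()

    def hook_fn(module, inputs):
        # `module` is the module the hook is registered on (here: mod).
        return (module.my_linear(inputs[0]),)

    return hook_fn


torch.manual_seed(0)
model = FixModule()
model.register_forward_pre_hook(get_hook_for(model))

# FixModule.l1 (weight, bias) followed by my_linear.l1 (weight, bias)
param_names = [name for name, _ in model.named_parameters()]

# The optimizer only sees model.parameters(); no manual bookkeeping needed.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = model.my_linear.l1.weight.detach().clone()

x = torch.randn(8, 1)
loss = model(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# True if the hook's layer was trained through model.parameters()
updated = not torch.equal(before, model.my_linear.l1.weight.detach())
```

Because `my_linear` is a regular submodule, calls such as `model.to(device)`, `model.state_dict()` and `model.train()`/`model.eval()` also cover it automatically.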

Awesome, thanks! Exactly what I was looking for.