I would like to modify an input to the forward call of a module that I cannot modify, using register_forward_pre_hook. The hook contains a forward call to another module with trainable parameters.

Let's assume we have a module FixModule that we cannot modify. We also can't change the input x to the forward call, as this happens somewhere deep inside a predefined PyTorch module.
import torch
import torch.nn as nn

class FixModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 1)

    def forward(self, x):
        return self.l1(x)
I define a forward pre-hook like this:
class Linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1, 1)

    def forward(self, x):
        return torch.sigmoid(self.l1(x))

linear = Linear()

# A forward pre-hook receives the module and a tuple of positional args;
# returning a tuple replaces the original input.
def hook(module, input):
    return (linear(input[0]),)

model = FixModule()
model.register_forward_pre_hook(hook)
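For reference, a quick sanity check that the hook fires (the input shape here is arbitrary):

x = torch.randn(4, 1)   # arbitrary example input
out = model(x)          # hook rewrites x via linear() before FixModule.forward runs
print(out.shape)        # torch.Size([4, 1])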
The hook takes the input x and applies a linear transformation followed by a sigmoid activation. However, the parameters of linear are not registered under model, so model.parameters() does not include them. I could add the parameters of both model and linear to the optimizer manually, but I am sure there is a nicer way to do this.
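The manual workaround I have in mind looks roughly like this (a sketch: itertools.chain is just one way to combine the two parameter iterables, and the optimizer and learning rate are arbitrary choices):

import itertools

# Hand over both parameter sets to a single optimizer
optimizer = torch.optim.SGD(
    itertools.chain(model.parameters(), linear.parameters()),
    lr=0.1,  # arbitrary example value
)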
Is there a clean way to achieve this?