Every time before the forward() function of an nn.Module is called, I want to check and possibly modify that nn.Module's parameters, including the weight and bias.
But the official source code of register_forward_pre_hook below doesn't really say whether this is achievable.
def register_forward_pre_hook(self, hook):
    r"""Registers a forward pre-hook on the module.

    The hook will be called every time before :func:`forward` is invoked.
    It should have the following signature::

        hook(module, input) -> None or modified input

    The hook can modify the input. User can either return a tuple or a
    single modified value in the hook. We will wrap the value into a tuple
    if a single value is returned (unless that value is already a tuple).

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(self._forward_pre_hooks)
    self._forward_pre_hooks[handle.id] = hook
    return handle
So the question is simple: can a forward pre-hook modify the module?
Yes, you can modify the module in any way you want during that hook (the module is the first argument of the hook function).
Note, though, that if you swap out weight Tensors, things might not behave properly with other components that were tracking the old Tensors, such as optimizers.
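To make the answer concrete, here is a minimal sketch of a pre-hook that modifies the module's weight in place before every forward call (the clamp operation and the names lin and clamp_hook are just for illustration):

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)

def clamp_hook(module, input):
    # The module is the first argument, so its parameters can be
    # inspected or modified in place here, before forward() runs.
    with torch.no_grad():
        module.weight.clamp_(-1.0, 1.0)

handle = lin.register_forward_pre_hook(clamp_hook)
out = lin(torch.randn(2, 4))  # clamp_hook runs first
handle.remove()
```

Because the change is made in place rather than by replacing the Tensor, the optimizer still sees the same Parameter object.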
Do optimizers just use the names of the parameters to track them? I'm looking to use register_forward_pre_hook for equalized learning rate as described in the Progressive GAN paper. In simple terms, the weight is divided by a constant after every learning step.
In this case, if I simply divide the weights in a forward pre-hook by, say, 2, how will that affect optimizer behavior?
I want to mask some of the weights before forward, so the hook needs to take the mask as input. But the hook signature hook(module, input) doesn't allow additional parameters.
Is what I want possible? Or should I just store the mask as a module buffer?
EDIT: I looked at the code in torch/nn/utils/prune.py and it looks like register_buffer is the right way.
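For reference, a minimal sketch of the buffer approach (the buffer name weight_mask and the hook below are my own, not taken from prune.py): the mask is registered as a buffer, so it moves with the module on .to()/.cuda() and is saved in the state_dict, and the hook reads it from the module instead of needing an extra argument.

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
# Store the mask on the module itself; buffers travel with the module
# and are included in its state_dict.
lin.register_buffer("weight_mask", (torch.rand(4, 4) > 0.5).float())

def mask_hook(module, input):
    # No extra parameters needed: the mask is looked up on the module.
    with torch.no_grad():
        module.weight.mul_(module.weight_mask)

handle = lin.register_forward_pre_hook(mask_hook)
out = lin(torch.randn(2, 4))  # forward sees the masked weight
```

Note this mutates the weight in place, so the masked entries stay zero after the first forward pass; prune.py instead keeps a separate weight_orig and recomputes the masked weight each time.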
Thanks for the explanation. In my case, suppose I want to add noise to the weights before the forward pass and restore the weights afterwards; what would be the better way to implement this without modifying the forward function of the module?
My way of doing it is to add a pre-hook that stores each module's noise in a dict:

noises = {}

def get_hook(mean=0.0, std=0.1):
    def hook(module, input):
        with torch.no_grad():
            if hasattr(module, 'weight'):
                noise = torch.randn_like(module.weight) * std + mean
                module.weight.add_(noise)
                noises[module] = noise
    return hook
and then remove the noise with a forward hook:

def get_hook():
    def hook(module, input, output):
        with torch.no_grad():
            if hasattr(module, 'weight'):
                module.weight.add_(-noises.pop(module))
    return hook
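Putting the two hooks together, a self-contained sketch (keying the shared noises dict by module is my own choice) that adds noise before forward and verifies the weights are restored afterwards might look like this:

```python
import torch
import torch.nn as nn

noises = {}  # shared between the pre-hook and the post-hook

def noise_pre_hook(module, input):
    # Add Gaussian noise to the weight just before forward() runs.
    with torch.no_grad():
        if hasattr(module, 'weight'):
            noise = torch.randn_like(module.weight) * 0.1
            module.weight.add_(noise)
            noises[module] = noise

def restore_post_hook(module, input, output):
    # Subtract the same noise right after forward() finishes.
    with torch.no_grad():
        if hasattr(module, 'weight'):
            module.weight.add_(-noises.pop(module))

lin = nn.Linear(4, 4)
original = lin.weight.detach().clone()
lin.register_forward_pre_hook(noise_pre_hook)
lin.register_forward_hook(restore_post_hook)

out = lin(torch.randn(2, 4))  # forward sees the noisy weight
```

After the call, lin.weight matches the original weight again (up to floating-point round-off, since the noise is added and subtracted in place).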