Use forward_pre_hook to modify nn.Module parameters

Every time before the forward() function of an nn.Module is called, I want to check and possibly modify this nn.Module’s parameters, including its weight and bias.

But the official source code of “register_forward_pre_hook” below doesn’t really say whether this is achievable.

    def register_forward_pre_hook(self, hook):
        r"""Registers a forward pre-hook on the module.

        The hook will be called every time before :func:`forward` is invoked.
        It should have the following signature::

            hook(module, input) -> None or modified input

        The hook can modify the input. User can either return a tuple or a
        single modified value in the hook. We will wrap the value into a tuple
        if a single value is returned (unless that value is already a tuple).

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_pre_hooks)
        self._forward_pre_hooks[handle.id] = hook
        return handle

So the question is simple: could a forward pre-hook modify the module?

P.S. I’m not really sure if this is the same question as the one raised here: Add hookable weights · Issue #5790 · pytorch/pytorch · GitHub


Hi,

Yes, you can modify the Module in any way you want during that hook (the module is the first arg of the hook function).
Note though that if you swap out weight Tensors, it might not behave properly with other elements that were tracking the old Tensors, like optimizers.
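
For example, a minimal sketch of a pre-hook that checks and modifies a module's parameters in-place right before every forward (the clamping is just an illustration):

import torch
import torch.nn as nn

def clamp_params(module, input):
    # Runs right before module.forward(); parameters can be inspected/modified in-place here
    with torch.no_grad():
        module.weight.clamp_(-1.0, 1.0)
        if module.bias is not None:
            module.bias.clamp_(-1.0, 1.0)

layer = nn.Linear(4, 4)
handle = layer.register_forward_pre_hook(clamp_params)
out = layer(torch.randn(2, 4))  # clamp_params runs first, then forward
handle.remove()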

Do optimizers just use the names of the parameters to track them? I’m looking to use register_forward_pre_hook for equalized learning rate as described in the Progressive GAN paper. In simple terms, the weight is divided by a constant after every learning step.

In this case, if I simply divide the weights in register_forward_pre_hook by say 2, how will it affect optimizer behavior?

The optimizers track the weights as Tensors directly, so as long as you modify that Tensor in-place, they will work just fine.

But if you do model.weight = foo, then the new model.weight will not correspond to the one that is in the optimizer, and so no learning will happen.
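
A minimal sketch of the in-place approach (the factor of 2 is just the example from the question above):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

def scale_hook(module, input):
    with torch.no_grad():
        # In-place: the optimizer keeps tracking the very same Tensor object
        module.weight.div_(2.0)
        # By contrast, `module.weight = nn.Parameter(...)` would create a new
        # Tensor that `opt` knows nothing about, and no learning would happen

layer.register_forward_pre_hook(scale_hook)
out = layer(torch.randn(2, 4))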


Hi albanD,

I want to mask some of the weights before the forward pass, so the hook needs to take the mask as an input. But the hook signature hook(module, input) doesn’t allow additional parameters.

Is what I want possible? Or should I just store the mask as a module buffer?

EDIT: I looked at the code in torch.nn.utils.prune and it looks like register_buffer is the right way.
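
For reference, a minimal sketch of the buffer approach (just an illustration, not how torch.nn.utils.prune does it internally):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
# A buffer travels with the module (device moves, state_dict, ...) but is not a Parameter
layer.register_buffer('mask', (torch.rand_like(layer.weight) > 0.5).float())

def mask_hook(module, input):
    with torch.no_grad():
        module.weight.mul_(module.mask)  # in-place, so optimizers keep tracking the weight

layer.register_forward_pre_hook(mask_hook)
out = layer(torch.randn(2, 4))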

Hi,

The hook can capture any variable from the parent scope. So if the mask is known ahead of time, you can do something like:

def get_hook(mask):
    def hook(mod, input):
        # Do something with mod, input and mask,
        # e.g. apply the mask to the weight in-place
        with torch.no_grad():
            mod.weight.mul_(mask)
    return hook

mod.register_forward_pre_hook(get_hook(mask))

Thanks, and just a note for others: if the mask is updated, you should remove the hook and register a new one with the updated mask.

mask = old_mask
handle = mod.register_forward_pre_hook(get_hook(mask))
mask = new_mask
handle.remove()
handle = mod.register_forward_pre_hook(get_hook(mask))
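
An alternative sketch that avoids re-registering (reusing the names from the snippet above): keep a single mask tensor and update it in-place, so the closure captured by the hook always sees the latest values. This assumes the mask keeps the same shape.

mask = old_mask.clone()
handle = mod.register_forward_pre_hook(get_hook(mask))
# later, instead of removing and re-adding the hook:
mask.copy_(new_mask)  # the hook captured `mask` itself, so it sees the update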

Hi albanD,

Thanks for the explanation. In my case, suppose I want to add noise to the weights before the forward pass and restore the weights afterwards; what would be a good way to implement this without modifying the forward function of the module?
My way of doing it is to add a pre-hook and store the noise in a dict:

def get_hook(mean=0.0, std=0.1):
    def hook(module, input):
        with torch.no_grad():
            if hasattr(module, 'weight'):
                noise = torch.randn_like(module.weight) * std + mean
                module.weight.add_(noise)
                noises[''] = noise  # `noises` is a dict defined outside the hook
    return hook

and then remove the noise with a forward hook:

def get_hook():
    def hook(module, input, output):
        with torch.no_grad():
            if hasattr(module, 'weight'):
                module.weight.add_(-noises[''])
    return hook
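
As a usage sketch, assuming the two factories above are given distinct names (say get_noise_hook and get_restore_hook, since both are called get_hook in the snippets) and that noises is the dict shared by both:

noises = {}
layer = nn.Linear(4, 4)

pre_handle = layer.register_forward_pre_hook(get_noise_hook(mean=0.0, std=0.1))
post_handle = layer.register_forward_hook(get_restore_hook())

out = layer(torch.randn(2, 4))  # noise is added just before forward and subtracted right after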

Hey!

This looks like a textbook use case for parametrization! See the Parametrizations Tutorial — PyTorch Tutorials 2.2.2+cu121 documentation
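
A minimal sketch of what that could look like with torch.nn.utils.parametrize (the AddNoise module here is just an illustration):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class AddNoise(nn.Module):
    def __init__(self, mean=0.0, std=0.1):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, weight):
        # Applied on the fly whenever layer.weight is accessed;
        # the underlying parameter itself is never modified
        return weight + torch.randn_like(weight) * self.std + self.mean

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, 'weight', AddNoise())
out = layer(torch.randn(2, 4))  # forward sees the noisy weight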
