How do I pass a keyword argument to the forward used by a pre-forward hook?

Given a PyTorch nn.Module with a forward pre-hook, e.g.

import torch
import torch.nn as nn

class NeoEmbeddings(nn.Embedding):
    def __init__(self, num_embeddings:int, embedding_dim:int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(self.neo_genesis)

    @staticmethod
    def neo_genesis(module, input, higgs_bosson=0):
        # `input` arrives as a tuple of forward()'s positional arguments
        if higgs_bosson:
            input = (input[0] + higgs_bosson,)
        return input
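For reference, here is a minimal sketch of the hook contract the class above relies on: a forward pre-hook is called as hook(module, args), where args is always a tuple of the positional inputs. Returning None leaves the inputs unchanged, and returning a single non-tuple value replaces them (PyTorch wraps it in a one-element tuple). The hook name add_one is purely illustrative, and nn.Identity is used only so the effect is easy to see.

```python
import torch
import torch.nn as nn

def add_one(module, args):
    # args is a tuple of the positional inputs passed to the module call
    (x,) = args
    # a non-None return value replaces the inputs; a single tensor
    # is wrapped into a one-element tuple before forward() is called
    return x + 1

m = nn.Identity()
m.register_forward_pre_hook(add_one)
out = m(torch.tensor([1, 2, 3]))
```

Because nn.Identity returns its input, out holds the already-incremented tensor.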

This lets the input tensor be manipulated before it reaches the actual forward() function, e.g.

>>> x = NeoEmbeddings(10, 5, 1)
>>> x(torch.tensor([0,2,5,8]))
tensor([[-1.6449,  0.5832, -0.0165, -1.3329,  0.6878],
        [-0.3262,  0.5844,  0.6917,  0.1268,  2.1363],
        [ 1.0772,  0.1748, -0.7131,  0.7405,  1.5733],
        [ 0.7651,  0.4619,  0.4388, -0.2752, -0.3018]],
       grad_fn=<EmbeddingBackward>)

>>> print(x._forward_pre_hooks)
OrderedDict([(25, <function NeoEmbeddings.neo_genesis at 0x1208d10d0>)])
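One detail worth keeping in mind: forward pre-hooks are run by the module's __call__, not by forward() itself. Calling x.forward(...) directly therefore skips every registered hook, while x(...) runs them first. The sketch below checks this with a plain nn.Embedding and a hypothetical recording hook named record.

```python
import torch
import torch.nn as nn

calls = []

def record(module, args):
    # returning None leaves the inputs unchanged
    calls.append(type(module).__name__)

emb = nn.Embedding(10, 5, padding_idx=1)
emb.register_forward_pre_hook(record)
idx = torch.tensor([0, 2, 5, 8])

emb.forward(idx)  # direct forward() call: hooks are skipped
after_forward = len(calls)
emb(idx)          # goes through __call__: the hook runs first
after_call = len(calls)
```

After running this, the hook has fired exactly once, during emb(idx).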

How can we pass arguments (*args or **kwargs) that the forward pre-hook needs but that the default forward() function does not accept?

Without modifying or overriding forward(), this is not possible:

>>> x = NeoEmbeddings(10, 5, 1)
>>> x.forward(torch.tensor([0,2,5,8]), higgs_bosson=2)

----------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-102-8705a40a3cc2> in <module>
      1 x = NeoEmbeddings(10, 5, 1)
----> 2 x.forward(torch.tensor([0,2,5,8]), higgs_bosson=2)

TypeError: forward() got an unexpected keyword argument 'higgs_bosson'
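In newer PyTorch releases (2.0 and later, as far as I know), register_forward_pre_hook accepts with_kwargs=True. The hook is then called as hook(module, args, kwargs) and must return either None or an (args, kwargs) pair, so it can pop the extra keyword before nn.Embedding.forward sees it. The sketch below assumes such a version. It must be invoked as x(...), because calling x.forward(...) directly bypasses the hook and still raises the TypeError above. The indices are chosen so that idx + 2 stays within the 10 embeddings.

```python
import torch
import torch.nn as nn

class NeoEmbeddings(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # with_kwargs=True: the hook also receives the call's keyword arguments
        self.register_forward_pre_hook(self.neo_genesis, with_kwargs=True)

    @staticmethod
    def neo_genesis(module, args, kwargs):
        kwargs = dict(kwargs)  # copy before editing
        # remove the extra keyword so nn.Embedding.forward never receives it
        higgs_bosson = kwargs.pop("higgs_bosson", 0)
        (input,) = args
        if higgs_bosson:
            input = input + higgs_bosson
        # a with_kwargs hook must return an (args, kwargs) pair (or None)
        return (input,), kwargs

x = NeoEmbeddings(10, 5, 1)
shifted = x(torch.tensor([0, 2, 5, 7]), higgs_bosson=2)
```

Here shifted looks up rows [2, 4, 7, 9], exactly as if those indices had been passed directly.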

Also asked on Stack Overflow: https://stackoverflow.com/questions/57703808/how-do-i-pass-a-keyword-argument-to-the-forward-used-by-a-pre-forward-hook

Hi,

Why is the forward pre-hook necessary here? Why not do that preprocessing at the beginning of forward()?
Or, if you want to reuse the parent class's forward, override forward(), do the preprocessing there, and then call the parent implementation with super().forward(input).
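A minimal sketch of that suggestion, reusing the names from the question: the extra keyword becomes a parameter of the overridden forward(), and no hook is needed at all. The example indices are chosen so that input + 2 stays within the 10 embeddings.

```python
import torch
import torch.nn as nn

class NeoEmbeddings(nn.Embedding):
    def forward(self, input, higgs_bosson=0):
        # the preprocessing that the pre-hook used to do
        if higgs_bosson:
            input = input + higgs_bosson
        # hand the preprocessed indices to nn.Embedding.forward
        return super().forward(input)

x = NeoEmbeddings(10, 5, padding_idx=1)
out = x(torch.tensor([0, 2, 5, 7]), higgs_bosson=2)
```

This keeps everything in one place and works on any PyTorch version, since nn.Module.__call__ forwards keyword arguments to forward() unchanged.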