Given a PyTorch nn.Module with a forward pre-hook, e.g.
import torch
import torch.nn as nn

class NeoEmbeddings(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_forward_pre_hook(self.neo_genesis)

    @staticmethod
    def neo_genesis(self, input, higgs_bosson=0):
        if higgs_bosson:
            input = input + higgs_bosson
        return input
A hook registered with register_forward_pre_hook(hook) is called as hook(module, inputs), where inputs is the tuple of positional arguments, and whatever it returns replaces those inputs (a non-tuple return value is wrapped into a one-element tuple). Since hooks are dispatched by nn.Module.__call__, the hook fires when the module is called as x(...), not when forward() is invoked directly. This lets an input tensor go through some manipulation before it reaches the actual forward() function, e.g.
>>> x = NeoEmbeddings(10, 5, 1)
>>> x(torch.tensor([0,2,5,8]))
tensor([[-1.6449, 0.5832, -0.0165, -1.3329, 0.6878],
[-0.3262, 0.5844, 0.6917, 0.1268, 2.1363],
[ 1.0772, 0.1748, -0.7131, 0.7405, 1.5733],
[ 0.7651, 0.4619, 0.4388, -0.2752, -0.3018]],
grad_fn=<EmbeddingBackward>)
>>> print(x._forward_pre_hooks)
OrderedDict([(25, <function NeoEmbeddings.neo_genesis at 0x1208d10d0>)])
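To see that the hook really fires through __call__ and receives the inputs as a tuple, a throwaway hook can be attached (the spy function here is purely illustrative):

>>> def spy(module, inputs):
...     print("pre-hook saw:", inputs)
...
>>> handle = x.register_forward_pre_hook(spy)
>>> _ = x(torch.tensor([0, 2]))
pre-hook saw: (tensor([0, 2]),)
>>> _ = x.forward(torch.tensor([0, 2]))  # nothing printed: forward() bypasses hooks
>>> handle.remove()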
How can we pass arguments (*args or **kwargs) that the forward pre-hook needs but that the default forward() function does not accept?
Without modifying/overriding the forward() function, this is not possible (the same TypeError is raised when calling x(...), since nn.Module.__call__ passes extra keyword arguments straight through to forward()):
>>> x = NeoEmbeddings(10, 5, 1)
>>> x.forward(torch.tensor([0,2,5,8]), higgs_bosson=2)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-102-8705a40a3cc2> in <module>
      1 x = NeoEmbeddings(10, 5, 1)
----> 2 x.forward(torch.tensor([0,2,5,8]), higgs_bosson=2)

TypeError: forward() got an unexpected keyword argument 'higgs_bosson'
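One way around this, without touching the forward() call signature, is to bind the extra argument when the hook is registered, e.g. with functools.partial. The sketch below assumes the value can be fixed at construction time rather than supplied per call:

from functools import partial

import torch
import torch.nn as nn

class NeoEmbeddings(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 padding_idx=-1, higgs_bosson=0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # Bind higgs_bosson now, so the registered hook keeps the
        # (module, inputs) signature that PyTorch expects.
        self.register_forward_pre_hook(
            partial(self.neo_genesis, higgs_bosson=higgs_bosson))

    @staticmethod
    def neo_genesis(module, inputs, higgs_bosson=0):
        # `inputs` is the tuple of positional arguments headed for forward().
        (tensor,) = inputs
        if higgs_bosson:
            tensor = tensor + higgs_bosson
        return (tensor,)  # the returned tuple replaces the original inputs

With higgs_bosson=2 bound at construction, x = NeoEmbeddings(10, 5, 1, higgs_bosson=2) followed by x(torch.tensor([0, 2, 5, 7])) embeds indices 2, 4, 7 and 9. If the value has to change from call to call, the forward() signature itself has to be extended (or the value stored as module state before each call), which is exactly the modification the question tries to avoid.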