Hi,
I am trying to use register_forward_pre_hook to modify one of my model's forward inputs, but the hook does not seem to receive any of the input arguments. Here is my code. It is a wrapper for the models of a library whose forward functions take keyword arguments, and in it I separated out the two arguments that I need to modify.
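```python
import torch.nn as nn

# GNNBasic is the base model class of the library being wrapped;
# import it from that library before defining this class.


class GNNWrapper(nn.Module):
    """The only purpose of this class is to pass `x` and `edge_index` as separate
    input arguments, so that torch hooks can recognize them."""

    def __init__(self, model: GNNBasic) -> None:
        super().__init__()
        self._model: GNNBasic = model

    def forward(self, x, edge_index, *args, **kwargs):
        # The wrapped model expects keyword arguments, so put x and edge_index back into kwargs.
        kwargs['x'] = x
        kwargs['edge_index'] = edge_index
        return self._model(*args, **kwargs)
```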
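One possible explanation, assuming the wrapper is called with keywords (e.g. `wrapper(x=x, edge_index=edge_index)`; the calling code isn't shown), is that a default forward pre-hook only receives the *positional* arguments as a tuple. Keyword arguments never reach it, so the tuple is empty. In PyTorch 2.0 and later, `register_forward_pre_hook` accepts `with_kwargs=True`, which makes the hook receive `(module, args, kwargs)`. The sketch below uses a hypothetical `ToyModel` in place of the real GNN to show both behaviours.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for the wrapped GNN; takes `x` and `edge_index`."""

    def forward(self, x, edge_index):
        return x.sum() + edge_index.sum()


seen = {}


def positional_only_hook(module, args):
    # Default pre-hook: receives only the positional arguments as a tuple.
    seen["positional"] = args


def kwargs_hook(module, args, kwargs):
    # Registered with with_kwargs=True: the keyword arguments are passed in too.
    seen["kwargs_keys"] = sorted(kwargs)
    # Example modification: double `x` before forward runs.
    kwargs = dict(kwargs)
    kwargs["x"] = kwargs["x"] * 2
    # When modifying inputs, return both args and kwargs as a 2-tuple.
    return args, kwargs


model = ToyModel()
h1 = model.register_forward_pre_hook(positional_only_hook)
h2 = model.register_forward_pre_hook(kwargs_hook, with_kwargs=True)

x = torch.ones(3)
edge_index = torch.zeros(2, 4, dtype=torch.long)
out = model(x=x, edge_index=edge_index)  # called with keywords only

print(seen["positional"])   # ()
print(seen["kwargs_keys"])  # ['edge_index', 'x']
print(out.item())           # 6.0, because x was doubled to [2, 2, 2]

# Remove the hooks once they are no longer needed.
h1.remove()
h2.remove()
```

Calling the wrapper positionally (`wrapper(x, edge_index)`) would also put both tensors into the default hook's `args` tuple. `with_kwargs=True` avoids depending on how callers pass them.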