The scenario is to add a hook to the model's forward and backward passes. Is it feasible to call real_model(x) inside hook_func.forward() to run the forward pass of real_model, as shown below?
Pseudo code:
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # model init
        ...

    def forward(self, x):
        # computation graph and returns x
        return x

real_model = Model()

class hook_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # some operation on x
        print(x)
        # call the real forward of a custom nn.Module()
        x = real_model(x)
        return x
    ...
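For reference, here is a minimal runnable sketch of one way the pseudo code above could be completed. It is not a definitive answer. It relies on the fact that forward() of a torch.autograd.Function runs with gradient recording disabled, so the call to real_model has to be wrapped in torch.enable_grad() and replayed manually in backward(). The names HookFunc, the Linear layer, and the tensor shapes are assumptions made for illustration and are not part of the original question.

```python
import torch


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical layer, standing in for the elided model init.
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


real_model = Model()


class HookFunc(torch.autograd.Function):
    """Illustrative counterpart of hook_func from the pseudo code."""

    @staticmethod
    def forward(ctx, x):
        # Some operation on x (the forward-side hook).
        print("forward hook, input shape:", tuple(x.shape))
        # Grad mode is off inside Function.forward, so turn it back on
        # to build a private graph for real_model.
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            out = real_model(x_in)
        # Keep the private graph alive for backward.
        ctx.x_in = x_in
        ctx.out = out
        return out.detach()

    @staticmethod
    def backward(ctx, grad_out):
        # Some operation on the gradient (the backward-side hook).
        print("backward hook, grad shape:", tuple(grad_out.shape))
        # Replay the private graph: this accumulates parameter gradients
        # on real_model and fills x_in.grad.
        torch.autograd.backward(ctx.out, grad_out)
        return ctx.x_in.grad


# The input must require grad, otherwise backward() is never invoked.
x = torch.randn(3, 4, requires_grad=True)
y = HookFunc.apply(x)
y.sum().backward()
```

If the goal is only to observe or modify activations and gradients, the built-in module hooks (register_forward_hook and register_full_backward_hook on torch.nn.Module) may be a simpler fit than a custom Function.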