I am wondering whether it is OK to define additional methods in an nn.Module.
My question is based on this previous thread:
Any different between model(input) and model.forward(input), which says that one should not call forward directly because some hooks are not run then. So is it discouraged to define additional methods? Putting everything in forward() with some flag would be an alternative, but an ugly one.
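For context, here is a minimal sketch of what I mean (the module and method names are just made up for illustration): a module with a regular forward() plus a separate helper method that is called directly, bypassing __call__:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

    # Additional method: called directly, so it does NOT go
    # through __call__ and any registered hooks are skipped.
    def encode_no_grad(self, x):
        with torch.no_grad():
            return self.fc(x)

model = Encoder()
x = torch.randn(3, 4)
out = model(x)                 # routed through __call__ (hooks run)
emb = model.encode_no_grad(x)  # plain method call (no hooks)
```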
I am not sure about your exact use case. In my understanding, if you are not using any backward hooks, you can always define additional methods inside your nn.Module.
Note that the hooks will only work if model(input) is called, which in turn processes the hooks internally before and after calling self.forward.
You can have a look at the code here:
```python
def _call_impl(self, *input, **kwargs):
    forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
    # If we don't have any hooks, we want to skip the rest of the logic in
    # this function, and just call forward.
    if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
            or _global_forward_hooks or _global_forward_pre_hooks):
        return forward_call(*input, **kwargs)
    # Do not call functions when jit is used
    full_backward_hooks, non_full_backward_hooks = [], []
    if self._backward_hooks or _global_backward_hooks:
        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
    if _global_forward_pre_hooks or self._forward_pre_hooks:
        for hook in (*_global_forward_pre_hooks.values(), *self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
    bw_hook = None
```
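A quick way to see this dispatch in action (a small sketch, not from the thread): register a forward hook, then compare calling the module via model(x) against calling forward() directly. Only the first path goes through _call_impl, so only it triggers the hook:

```python
import torch
import torch.nn as nn

calls = []

model = nn.Linear(4, 2)
# Record every time the forward hook fires.
model.register_forward_hook(lambda mod, inp, out: calls.append("hook"))

x = torch.randn(1, 4)
model(x)          # __call__ -> _call_impl -> forward + hook
model.forward(x)  # calls forward directly; the hook is skipped

print(calls)  # only one entry, from the model(x) call
```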
Thanks for your answer.
I guess this means there should be no problem using other methods for inference and model() for training.