Do we always need to define the forward function for a subclass of nn.Module?

A common practice when developing neural models is to define a subclass of nn.Module with both an __init__ method and a forward method. We can then create an instance, model, of that class and run forward by calling model(...). I never questioned this until I found a PyTorch project on GitHub whose author doesn’t define forward at all. Instead, the author wrote his own method that computes the loss from the modules defined in __init__. I don’t understand why he does that. In principle, is it always recommended to put the computation we want the model to perform in forward? If we skip forward, define our own method, and call it explicitly, is there any problem or disadvantage?

When you define an nn.Module, you usually store submodules, parameters, buffers, or other arguments in its __init__ method and write the actual forward logic in its forward method.
This is convenient because nn.Module.__call__ runs any registered hooks etc. and finally calls into the forward method.
However, you don’t have to use this approach and could write your model in a completely functional way.
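To make the difference concrete, here is a minimal sketch, assuming a toy module with a hypothetical custom method named `compute` that performs the same computation as `forward`. The class name `Net`, the layer sizes, and the hook body are illustrative choices, not taken from the project mentioned in the question.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

    # Hypothetical custom entry point: same computation as forward,
    # but it is not routed through nn.Module.__call__.
    def compute(self, x):
        return self.fc(x)


model = Net()
calls = []

# Forward hook registered on the top-level module; it records output shapes.
handle = model.register_forward_hook(
    lambda module, inputs, output: calls.append(tuple(output.shape))
)

x = torch.randn(3, 4)
out_call = model(x)            # goes through __call__, so the hook fires
out_custom = model.compute(x)  # bypasses __call__ on model, so it does not

handle.remove()
print(calls)  # [(3, 2)]: only the model(x) call was recorded
```

Both calls return the same values, so the model itself works either way. Note that hooks registered on the submodule `self.fc` would still fire inside `compute`, because `self.fc(x)` is itself a regular module call.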

I’ve written some more information in this thread which might be interesting to you.


But if we define the behavior of the model in our own method instead of defining forward and going through __call__, do we need to take care of the hooks ourselves, since __call__ normally handles them? Specifically, if we define a method called func and then call model.func to do the computation, we would have to handle hooks inside func ourselves, right?
Actually, I don’t quite see why hooks are necessary. What are they essentially for? I suppose hooks do not really affect the forward pass or backpropagation and are only used for auxiliary work such as visualizing the network. In other words, our model can still work properly even if we never register any hooks, e.g., when we define the computation under our own entry point instead of forward.

Yes, your model might work without using forward.
However, utilities such as nn.DataParallel rely on the __call__ method and thus on the implementation of forward.
Also, if other users want to use hooks with your model, they would have to reimplement your computation in forward.

Hooks can be used for debugging or visualizations. But you could also easily add auxiliary losses to your model using forward hooks or some gradient regularization using register_hook.
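As a rough sketch of both ideas: the example below uses a forward hook to collect an auxiliary L1 penalty on an intermediate activation, and a tensor hook registered with `register_hook` to modify a gradient during backpropagation. The toy model, the penalty weight of 0.01, and the clamp range are all assumptions for illustration, and gradient clamping is used here only as a simple stand-in for a real gradient regularizer.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Auxiliary loss via a forward hook: collect an L1 penalty on the
# ReLU activations each time the module is called.
aux_terms = []


def l1_activation_penalty(module, inputs, output):
    aux_terms.append(output.abs().mean())


fwd_handle = model[1].register_forward_hook(l1_activation_penalty)

# Gradient manipulation via a tensor hook: the returned tensor replaces
# the gradient of the first layer's weight (assumed clamp range).
grad_handle = model[0].weight.register_hook(
    lambda grad: grad.clamp(-0.1, 0.1)
)

x = torch.randn(5, 4)
target = torch.randn(5, 1)

pred = model(x)  # must go through __call__ for the forward hook to fire
loss = nn.functional.mse_loss(pred, target) + 0.01 * sum(aux_terms)
loss.backward()

fwd_handle.remove()
grad_handle.remove()
```

In a training loop, `aux_terms` would need to be cleared after every step, otherwise penalties from earlier iterations would accumulate in the list.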

I don’t see much benefit in using custom names.
