Defining additional methods in an nn.Module

Hi everyone,
I am wondering whether it is OK to define additional methods in an nn.Module:

class MyModule(nn.Module):
    def classify(...):
    def detect(...):
    def forward(...):

My question is based on this previous thread: "Any different between model(input) and model.forward(input)", which says that one should not call forward() directly because hooks are then skipped. So is it discouraged to define additional methods? Putting everything in forward() with some flag would be an alternative, but an ugly one.

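I am not sure about your exact use case. As far as I understand, you can always define additional methods inside an nn.Module. It only matters if you rely on forward or backward hooks, because calling those methods directly does not trigger the hooks registered on the module.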

Note that the hooks only run when model(input) is called. That call goes through nn.Module.__call__, which processes the hooks internally before and after calling forward().
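To make the difference concrete, here is a minimal sketch. The layer sizes and the body of classify() are assumptions made up for illustration, not anything from the original post. It registers a forward hook and counts how often it fires for model(x), model.forward(x), and an extra method that calls forward() directly.

```python
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)  # sizes chosen arbitrarily for the demo

    def forward(self, x):
        return self.linear(x)

    def classify(self, x):
        # Hypothetical extra method: it calls forward() directly,
        # so hooks registered on this module are skipped.
        return self.forward(x).argmax(dim=1)


calls = []
model = MyModule()
# The hook returns None (list.append returns None), so the output is unchanged.
handle = model.register_forward_hook(
    lambda module, inputs, output: calls.append("hook")
)

x = torch.randn(3, 4)

model(x)                      # goes through nn.Module.__call__, hook runs
n_after_call = len(calls)

model.forward(x)              # bypasses __call__, hook does not run
model.classify(x)             # same: forward() is called directly
n_after_direct = len(calls)

handle.remove()
print(n_after_call, n_after_direct)
```

Hooks registered on submodules such as self.linear should still fire in all three cases, because forward() itself invokes self.linear(x) through the call operator.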

You can have a look at the code here:

Thanks for your answer.

I guess this means there should be no problem with using other methods for inference and model() for training.
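A rough sketch of that split, under my own assumptions (the two-layer architecture, the sizes, and the classify() body are hypothetical): training goes through model(x), and the inference helper calls self(x) instead of self.forward(x), so any hooks registered on the module would still run if they are added later.

```python
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical layers; the sizes are arbitrary.
        self.backbone = nn.Linear(4, 8)
        self.cls_head = nn.Linear(8, 3)

    def forward(self, x):
        return self.cls_head(torch.relu(self.backbone(x)))

    @torch.no_grad()
    def classify(self, x):
        # self(x) rather than self.forward(x): goes through __call__,
        # so hooks on this module are not bypassed.
        logits = self(x)
        return logits.argmax(dim=1)


model = MyModule()
x = torch.randn(5, 4)

# Training-style call: model(x) builds the graph and gradients flow.
loss = model(x).sum()
loss.backward()

# Inference-style call through the extra method.
model.eval()
preds = model.classify(x)
print(preds.shape)
```

Whether to wrap the helper in torch.no_grad() is a design choice. It fits pure inference, but should be dropped if the method's output is ever needed for a loss.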

Yes, you are correct.
