Why use MSELoss instead of mse_loss?

Is there a specific reason why torch.nn.functional.mse_loss is wrapped in the class torch.nn.MSELoss (and likewise for the other loss functions)? It feels slightly cumbersome to first instantiate a class and then call its forward. Is there some autograd magic going on?

Hi,

First, you should never call the .forward() method of an nn.Module directly; call the module instance with the inputs instead.
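For illustration, a minimal sketch of the intended calling convention (the tensor shapes here are arbitrary):

```python
import torch
import torch.nn as nn

loss_fn = nn.MSELoss()
pred = torch.randn(4, 3, requires_grad=True)
target = torch.randn(4, 3)

# Correct: calling the instance goes through __call__, which runs
# any registered hooks before dispatching to forward().
loss = loss_fn(pred, target)

# Avoid: loss_fn.forward(pred, target) skips the hook machinery.
```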

The advantages of the class version are:

  • For modules with parameters or buffers, it is nicer to just save these in self.
  • It allows you to use simple container constructions like nn.Sequential().
  • It makes it easier to change some part of a model by only changing the class that is used, instead of rewriting the whole forward function just to change, for example, the activation function (see the sketch after this list).
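A small sketch of the last two points; the Net class and its layer sizes are made up for illustration:

```python
import torch.nn as nn

class Net(nn.Module):
    # The activation is an ordinary attribute stored in self, so
    # swapping it is a one-argument change rather than a rewrite
    # of forward().
    def __init__(self, activation=nn.ReLU):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.act = activation()
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

relu_net = Net()           # default activation
tanh_net = Net(nn.Tanh)    # swapped without touching forward()

# Module instances also compose directly with containers:
head = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))
```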

The functional version is necessary to bind to the backend, and some users prefer not to use the nn.Module API and just use the functional API directly. So both of them are exposed, and the user can choose one or the other (or mix them) depending on what is more appropriate for the use case at hand.
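For example, the two APIs compute the same thing and can be used interchangeably (a minimal sketch):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

pred = torch.randn(4, 3, requires_grad=True)
target = torch.randn(4, 3)

# Module API: instantiate once, reuse as a callable.
criterion = nn.MSELoss()
loss_module = criterion(pred, target)

# Functional API: call directly, no instance needed.
loss_functional = F.mse_loss(pred, target)

# nn.MSELoss dispatches to F.mse_loss, so the results match.
assert torch.allclose(loss_module, loss_functional)
```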
