Custom loss implementation (function vs class)

What would make one implement a custom loss as a class instead of as a regular function?

For instance, this implementation of the YOLOv1 loss subclasses nn.Module, whereas this MSE loss is implemented as a plain function.

I guess my question is, why is every loss implemented as a class in PyTorch if one can simply define a function as explained here?

The advantage of using classes (nn.Module in this case) is that they can store internal state, so you don't have to pass every argument explicitly on each call. E.g. if you want the unreduced loss, you could use either the module or the functional API:

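```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(10, 10)
y = torch.randn(10, 10)

# module: reduction is stored in the criterion
criterion = nn.MSELoss(reduction='none')
loss = criterion(x, y)

# functional API: reduction is passed on every call
loss = F.mse_loss(x, y, reduction='none')
```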

So it can be convenient to create the criterion once and just reuse it.
The same applies to a weighted loss or to any other module.
E.g. for a linear layer you could use both APIs as well:

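```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(10, 10)

# module: nn.Linear creates and owns its weight and bias
lin = nn.Linear(10, 10)
out = lin(x)

# functional API: you create and manage the parameters yourself
weight = nn.Parameter(torch.randn(10, 10))  # (out_features, in_features)
bias = nn.Parameter(torch.randn(10))
out = F.linear(x, weight, bias)
```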
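A small sketch of what "storing internal state" buys you for the linear layer: the module registers its parameters, so `.parameters()` (and thus an optimizer) finds them automatically, while with `F.linear` the caller has to collect them. The seed, the copied weights and the `lr` value are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(10, 10)

# Module API: nn.Linear registers weight and bias itself,
# so they show up in named_parameters() / parameters().
lin = nn.Linear(10, 10)
print([name for name, _ in lin.named_parameters()])  # ['weight', 'bias']
opt_module = torch.optim.SGD(lin.parameters(), lr=0.1)

# Functional API: F.linear is stateless, so we own the tensors
# and must hand them to the optimizer explicitly.
weight = nn.Parameter(lin.weight.detach().clone())
bias = nn.Parameter(lin.bias.detach().clone())
optimizer = torch.optim.SGD([weight, bias], lr=0.1)

# With identical weights, both compute x @ weight.T + bias.
out_module = lin(x)
out_functional = F.linear(x, weight, bias)
print(torch.allclose(out_module, out_functional))  # True
```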

Depending on the flexibility you need and your actual workload, pick whichever fits your use case best.
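To illustrate the "weighted loss" case mentioned above, here is a minimal sketch where the class weights and reduction are configured once on `nn.CrossEntropyLoss` and then reused, compared against passing the same configuration to `F.cross_entropy` on every call. The weight values, the `training_step*` helper names and the shapes are made up for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 3
class_weight = torch.tensor([1.0, 2.0, 0.5])  # arbitrary example weights

# Module API: configure weight and reduction once...
criterion = nn.CrossEntropyLoss(weight=class_weight, reduction='sum')

def training_step(logits, target):
    # ...then reuse the criterion without repeating its configuration.
    return criterion(logits, target)

# Functional API: the configuration travels with every call.
def training_step_functional(logits, target):
    return F.cross_entropy(logits, target, weight=class_weight, reduction='sum')

logits = torch.randn(4, num_classes)
target = torch.tensor([0, 2, 1, 2])
print(torch.allclose(training_step(logits, target),
                     training_step_functional(logits, target)))  # True
```

A side benefit of the module form: the loss module stores `weight` as a buffer, so `criterion.to(device)` moves it along with the rest of the model.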

Thank you very much for answering the question. As a follow-up: if I am implementing a paper from scratch for learning purposes and it proposes a new loss (e.g. RetinaNet or YOLO), am I better off implementing the loss as a function and putting it in a utils/metrics file, or should it get its own dedicated class? Note that I am only going to call it once and use it only in that particular repository.

I would probably implement the new loss function as an nn.Module and try to provide the same or similar arguments as the default loss functions in PyTorch (e.g. at least the reduction parameter, and maybe a weight).
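As an example of that advice, here is a minimal sketch of the binary (sigmoid) focal loss from the RetinaNet paper, FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t), written as an nn.Module that accepts a `reduction` argument like the built-in losses. The class name `FocalLoss`, the defaults `alpha=0.25, gamma=2.0`, and the convention that a negative `alpha` disables the class weighting are assumptions for this sketch, not a reference implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Hypothetical binary focal loss module (sketch).

    Mirrors built-in losses by accepting reduction in
    {'none', 'mean', 'sum'}. A negative alpha disables alpha weighting.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0,
                 reduction: str = 'mean'):
        super().__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"invalid reduction: {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Per-element -log(p_t), computed stably from the logits.
        ce = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
        p = torch.sigmoid(logits)
        p_t = p * target + (1 - p) * (1 - target)
        # Down-weight easy examples (p_t close to 1).
        loss = ce * (1 - p_t) ** self.gamma
        if self.alpha >= 0:
            alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
            loss = alpha_t * loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss


# Usage: targets are float 0/1 labels with the same shape as the logits.
torch.manual_seed(0)
logits = torch.randn(8, 4)
target = torch.randint(0, 2, (8, 4)).float()
criterion = FocalLoss(reduction='none')
loss = criterion(logits, target)  # per-element loss, shape (8, 4)
```

A quick sanity check for such a sketch: with `gamma=0` and `alpha` disabled it should reduce to plain binary cross-entropy. torchvision also ships `torchvision.ops.sigmoid_focal_loss`, which could be used to cross-check the numbers.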

EDIT: Internally, the nn.Module could (and probably should) call the functional form, so that users can pick whichever approach they prefer. PyTorch follows the same pattern: its built-in nn.Modules call the functional API internally.
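A minimal sketch of that "functional first, thin module wrapper" pattern, using a hypothetical per-element weighted MSE (`weighted_mse_loss` and `WeightedMSELoss` are illustrative names, not PyTorch APIs). The module only stores configuration and delegates to the function, so both entry points always agree.

```python
import torch
import torch.nn as nn


def weighted_mse_loss(input, target, weight, reduction='mean'):
    """Hypothetical functional form: MSE with a per-element weight."""
    loss = weight * (input - target) ** 2
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise ValueError(f"invalid reduction: {reduction!r}")


class WeightedMSELoss(nn.Module):
    """Thin wrapper: stores the configuration, delegates the math."""

    def __init__(self, weight, reduction='mean'):
        super().__init__()
        # A buffer moves with .to(device) and is saved in state_dict(),
        # but is not returned by .parameters(), so it is not trained.
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        return weighted_mse_loss(input, target, self.weight, self.reduction)


torch.manual_seed(0)
weight = torch.tensor([1.0, 2.0, 3.0])  # broadcast over the last dim
x = torch.randn(4, 3)
y = torch.randn(4, 3)
criterion = WeightedMSELoss(weight, reduction='sum')
# Both entry points compute the same value.
print(torch.allclose(criterion(x, y), weighted_mse_loss(x, y, weight, 'sum')))  # True
```

One design choice worth making explicitly: here 'mean' averages over all elements, whereas `nn.CrossEntropyLoss` with a `weight` divides by the sum of the target weights instead, so document whichever convention your loss uses.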
