Best/clean way to define your own loss function

Hi All,

I saw many questions/answers here about custom loss functions. However, I didn't find what I'm looking for.

My simple question is: from PyTorch's perspective, what is the best/cleanest way to write your own loss function: class-based or function-based?

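PyTorch is rather modular. Any scalar tensor can serve as a loss, so any function of tensors, class labels, or parameters that returns a single scalar value is a loss function, as long as it is built from differentiable tensor operations so that autograd can backpropagate through it.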
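To illustrate that both styles work the same way, here is a minimal sketch of one loss written both ways. PyTorch itself ships both forms for its built-in losses (e.g. the module `nn.MSELoss` and the function `torch.nn.functional.mse_loss`). The names `mse_loss_fn` and `WeightedMSELoss`, and the `weight` argument, are hypothetical and only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Function-based: a plain function of tensors that returns a scalar tensor.
def mse_loss_fn(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return ((pred - target) ** 2).mean()

# Class-based: subclass nn.Module and implement forward().
# Handy when the loss carries configuration (here a hypothetical weight).
class WeightedMSELoss(nn.Module):
    def __init__(self, weight: float = 1.0):
        super().__init__()
        self.weight = weight

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.weight * ((pred - target) ** 2).mean()

pred = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
target = torch.tensor([1.0, 2.0, 5.0])

loss_a = mse_loss_fn(pred, target)
loss_b = WeightedMSELoss(weight=1.0)(pred, target)
loss_builtin = F.mse_loss(pred, target)

# Both return a 0-dim tensor, so autograd can backpropagate from either.
loss_a.backward()
print(loss_a.item(), loss_b.item(), pred.grad)
```

The choice is mostly style: a function is enough for a stateless loss, while an `nn.Module` is convenient when the loss holds hyperparameters or buffers, or when you want it to look like the built-in loss classes.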

Beyond coding style, the only things to consider are the convexity of the function, its limiting behavior (how does it look for extreme values?), and its numerical stability (would you run into precision issues with float32 numbers?).
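As an example of the numerical-stability point, here is a sketch of binary cross-entropy written two ways, assuming logits as input. Taking `log(sigmoid(x))` directly overflows to `inf` in float32 for extreme logits, while rewriting it with `F.logsigmoid` stays finite. The helper names `naive_bce` and `stable_bce` are hypothetical; in practice the built-in `F.binary_cross_entropy_with_logits` already handles this.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([-200.0, 0.0, 200.0])  # float32 by default
targets = torch.tensor([1.0, 1.0, 0.0])

def naive_bce(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # sigmoid(-200) underflows to 0 and sigmoid(200) rounds to 1 in float32,
    # so one of the log terms becomes log(0) = -inf.
    p = torch.sigmoid(logits)
    return -(targets * torch.log(p) + (1 - targets) * torch.log(1 - p))

def stable_bce(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Use log(sigmoid(x)) = logsigmoid(x) and log(1 - sigmoid(x)) = logsigmoid(-x).
    return -(targets * F.logsigmoid(logits) + (1 - targets) * F.logsigmoid(-logits))

naive = naive_bce(logits, targets)
stable = stable_bce(logits, targets)
builtin = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
print(naive)   # contains inf for the extreme logits
print(stable)  # finite: roughly 200, 0.6931, 200
```

Both functions are mathematically the same loss; only the stable one survives float32 arithmetic at the extremes.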
