Custom loss function without a backward function: how does it work?

Hi, I’m trying to write a custom loss, and I saw a few ways to do it:

  1. Make a new class inheriting from Function, with forward and backward functions, as in
    this repo:
    https://github.com/mattmacy/torchbiomed/blob/master/torchbiomed/loss.py#L13
    By the way, I have no idea what the “grad_output” argument of the backward function means
    (when I call loss.backward() with a “normal” loss function, I don’t pass any argument).
    I’ve read this: https://pytorch.org/docs/stable/notes/extending.html
    but things are still messy…

  2. Just a small, simple function, as in this repo:
    https://github.com/perone/medicaltorch/blob/master/medicaltorch/losses.py#L4
    which is then used in:
    https://github.com/perone/medicaltorch/blob/master/examples/gmchallenge_unet.py#L111
    with the backward() call on line 115.
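For the “grad_output” question in option 1, here is a minimal sketch, assuming a simple sum-of-squared-errors loss that does not come from either repo (the class name `SquaredErrorLoss` is hypothetical). `backward` receives the gradient of the final result with respect to the output of `forward`, and has to return the gradient with respect to each input of `forward` (the chain rule). Calling `loss.backward()` with no argument on a scalar behaves like passing `torch.tensor(1.0)`, so in the plain case `grad_output` is just 1.

```python
import torch


class SquaredErrorLoss(torch.autograd.Function):
    """Option 1: a loss with a hand-written backward (hypothetical example)."""

    @staticmethod
    def forward(ctx, input, target):
        # Save what backward will need to compute the gradient.
        ctx.save_for_backward(input, target)
        return ((input - target) ** 2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output = d(final result) / d(output of forward).
        input, target = ctx.saved_tensors
        # Chain rule: d(final)/d(input) = grad_output * d(loss)/d(input).
        grad_input = grad_output * 2 * (input - target)
        # One return value per forward input; target gets no gradient.
        return grad_input, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
t = torch.tensor([0.0, 0.0, 0.0])

loss = SquaredErrorLoss.apply(x, t)
loss.backward()  # same as loss.backward(torch.tensor(1.0)): grad_output == 1
print(x.grad)  # tensor([2., 4., 6.])

# Scaling the loss scales grad_output: here backward() receives grad_output == 3.
x.grad = None
loss = SquaredErrorLoss.apply(x, t)
(3 * loss).backward()
print(x.grad)  # three times the previous gradient

# Numerically check the hand-written backward (gradcheck expects float64 inputs).
inp = torch.randn(3, dtype=torch.double, requires_grad=True)
tgt = torch.randn(3, dtype=torch.double)
print(torch.autograd.gradcheck(SquaredErrorLoss.apply, (inp, tgt)))  # True
```

With a hand-written backward, nothing guarantees the gradient is right, which is why a check like `gradcheck` is worth running.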

Now my biggest questions are: why should I use the first option when the second is much easier?
And how come they can call backward() on the result of just a plain function?

Can someone please help me understand this? I feel like everything is just a big mess…
Thanks!

Hi,

You need option 1 if you want custom gradients for your loss function (because it is not differentiable, or because the gradients returned by autograd are not what you want), or if the autograd backward is too slow because of the size of the forward and you could implement it more efficiently by hand.
If you’re not in any of these cases, go with option 2: it’s simpler and will always return the correct gradients!
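As an illustration of the “not differentiable” case, here is a sketch using the straight-through estimator, a common trick that is not mentioned in the repos above (the name `RoundSTE` and the toy numbers are hypothetical). `torch.round` has a zero gradient almost everywhere, so plain autograd gives no learning signal through it; a custom `Function` can instead pass the incoming gradient straight through.

```python
import torch


class RoundSTE(torch.autograd.Function):
    """Rounding with a straight-through gradient (hypothetical example)."""

    @staticmethod
    def forward(ctx, input):
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Treat rounding as the identity: pass the incoming gradient through.
        return grad_output


x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
target = torch.tensor([0.0, 2.0, 0.0])

# Option 2 through torch.round: autograd's gradient is all zeros.
loss_plain = ((torch.round(x) - target) ** 2).sum()
loss_plain.backward()
grad_plain = x.grad.clone()
print(grad_plain)

# Option 1 with a custom backward: a non-zero gradient reaches x.
x.grad = None
loss_ste = ((RoundSTE.apply(x) - target) ** 2).sum()
loss_ste.backward()
print(x.grad)
```

The custom gradient here is a deliberate approximation, which is exactly the kind of choice autograd cannot make for you.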

So in a simple case such as the dice loss, option 2 will be fine, right?
But there is still one thing I need to understand. They have this function:

And here: https://github.com/perone/medicaltorch/blob/master/examples/gmchallenge_unet.py#L115
they call backward() on its result. Am I misunderstanding something? Isn’t backward() just for loss functions? How can it know this is a loss function, when there is nothing special about this function?

I know the dice loss has many implementations just waiting for me to grab them, but I need to understand this first, sorry.

Thanks for your help.

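The autograd engine in PyTorch is much more general than its neural-net uses.
It records the operations you perform on tensors that require gradients, so you can get gradients for anything you compute with a Tensor, not only the loss of a neural net. backward() does not need to know that a value is a “loss”: it only needs the tensor it is called on to be a scalar (or to be given an explicit gradient argument).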
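To make that concrete, here is a minimal sketch of a soft dice loss written as an ordinary Python function (option 2). It is not the medicaltorch code linked above: `dice_loss`, the `smooth` term, the `1 - dice` form and the tensor shapes are all assumptions for illustration. Autograd only sees the tensor operations inside `dice_loss`, so backward() works on its result exactly as it does on the small polynomial at the end.

```python
import torch


def dice_loss(input, target, smooth=1.0):
    # Hypothetical soft dice loss built only from ordinary tensor ops.
    # input: predicted probabilities, target: binary mask of the same shape.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    dice = (2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)
    return 1.0 - dice


# Nothing marks dice_loss as a "loss": autograd just tracks the ops inside it.
logits = torch.randn(1, 1, 4, 4, requires_grad=True)
probs = torch.sigmoid(logits)
mask = (torch.rand(1, 1, 4, 4) > 0.5).float()

loss = dice_loss(probs, mask)
loss.backward()
print(logits.grad.shape)  # torch.Size([1, 1, 4, 4])

# The same mechanism works for any scalar computed from tensors.
x = torch.tensor(3.0, requires_grad=True)
y = x ** 3 + 2 * x  # dy/dx = 3x^2 + 2 = 29 at x = 3
y.backward()
print(x.grad)  # tensor(29.)
```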
