Define backward() function in my own loss function

My loss function is special and cannot be handled by autograd, so I want to define backward() in my loss function myself. But I don't know whether I can compute the gradient of the loss by hand and have it propagate back through the rest of the network automatically.

import torch

class MyLoss(torch.autograd.Function):

    @staticmethod
    def forward(self, input, label):
        ctx.save_for_backward(input, label)
        #my code
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        input, label = ctx.saved_tensors
        # my code
        return grad_input, None


model = MyNetwork()  # this is my network
myloss = MyLoss()

for t in range(500):
    output = model(input)
    loss = myloss(output, label)
    loss.backward()

Is the above code feasible? I am not sure whether the gradient will propagate back through the network automatically.
Thank you.

Yes, you can do that. See more details in the doc.
Be careful:

  • You should not instantiate the Function. You should call loss = MyLoss.apply(output, label) directly. If you prefer working with an instance, you can simply wrap it in a new module (see the full sketch after this list):
import torch.nn as nn

class MyLossMod(nn.Module):
    def forward(self, input, target):
        return MyLoss.apply(input, target)

# Then you can use either of the two
loss = MyLoss.apply(input, target)
# or
mylossmod = MyLossMod()
loss = mylossmod(input, target)
  • The first argument of forward should be ctx, not self, but I guess that's just a typo.
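
Putting this together, here is a minimal self-contained sketch. The squared-error math, the nn.Linear stand-in for your network, and the tensor shapes are only assumptions to make the example runnable; substitute your own forward and backward computation:

import torch
import torch.nn as nn

class MyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, label):
        ctx.save_for_backward(input, label)
        # assumed placeholder: mean squared error written out by hand
        return ((input - label) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        input, label = ctx.saved_tensors
        # hand-derived gradient of the loss w.r.t. input
        grad_input = grad_output * 2.0 * (input - label) / input.numel()
        # the label needs no gradient, so return None for it
        return grad_input, None

model = nn.Linear(10, 1)                # stand-in for MyNetwork
input = torch.randn(8, 10)
label = torch.randn(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for t in range(500):
    optimizer.zero_grad()
    output = model(input)
    loss = MyLoss.apply(output, label)  # .apply, not MyLoss()
    loss.backward()                     # gradient flows back through the whole model
    optimizer.step()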

Thank you for your reply!

Please, when can a loss function not be handled by autograd? And if that is the case, do we need to write two classes, one extending the Function class and the other extending nn.Module (like in the tutorial https://pytorch.org/docs/master/notes/extending.html)?
Thank you!

You will need to define a custom Function.
The custom Module presented there is just for convenience. You can use YourFunction.apply() in your code if you don’t want to use Modules.
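
As a side note, if you want to confirm that a hand-written backward matches the true derivative, torch.autograd.gradcheck can compare it against numerical gradients. A small sketch, reusing the MyLoss sketch above and assuming double-precision inputs (which gradcheck needs for accuracy):

import torch

# double precision and requires_grad=True on the input are needed for a reliable check
inp = torch.randn(8, 1, dtype=torch.double, requires_grad=True)
lbl = torch.randn(8, 1, dtype=torch.double)

# prints True if the analytical backward matches the numerical gradients
print(torch.autograd.gradcheck(MyLoss.apply, (inp, lbl)))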