My loss function is special and cannot be handled by autograd, so I want to define backward() in my loss function myself. But I don't know whether I can compute the gradient of the loss by hand and have it propagate back through the network automatically.
import torch

class MyLoss(torch.autograd.Function):
    @staticmethod
    def forward(self, input, label):
        ctx.save_for_backward(input, label)
        # my code
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        input, label = ctx.saved_tensors
        # my code
        return grad_input, None

model = MyNetwork()  # this is my network
myloss = MyLoss()
for t in range(500):
    output = model(input)
    loss = myloss(output, label)
    loss.backward()
Is the above code feasible? I am not sure whether the gradient will propagate back through the network automatically.
Thank you.
Yes, you can do that. See more details in the doc.
Be careful:
You should not instantiate an instance of the Function. You should call loss = MyLoss.apply(output, label) instead. If you prefer to instantiate an instance, you can simply create a small module like:
import torch.nn as nn

class MyLossMod(nn.Module):
    def forward(self, input, target):
        return MyLoss.apply(input, target)

# Then you can use either of the two:
loss = MyLoss.apply(input, target)
# or
mylossmod = MyLossMod()
loss = mylossmod(input, target)
Also, the first argument to forward should be ctx, not self, but I guess that's just a typo.
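For reference, here is a minimal corrected sketch of the pattern above, with ctx in both methods. The actual loss computation is a placeholder (a simple mean squared error) standing in for your own formula, so the hand-written gradient matches that placeholder only:

import torch

class MyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, label):
        # Save the tensors needed to compute the gradient later.
        ctx.save_for_backward(input, label)
        # Placeholder loss: mean squared error; replace with your own computation.
        return ((input - label) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        input, label = ctx.saved_tensors
        # Hand-written gradient of the placeholder loss w.r.t. input.
        grad_input = grad_output * 2.0 * (input - label) / input.numel()
        # One return value per forward argument; label needs no gradient.
        return grad_input, None

Calling loss.backward() on the result of MyLoss.apply(output, label) will then feed grad_input back through the rest of the network automatically.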
Please, when can a loss function not be handled by autograd? And if that is the case, do we then need to write two classes, one extending the Function class and the other extending nn.Module (as in the tutorial https://pytorch.org/docs/master/notes/extending.html)?
Thank you!
You will need to define a custom Function.
The custom Module presented there is just for convenience. You can use YourFunction.apply() in your code if you don’t want to use Modules.
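For example, a rough sketch of a training loop that uses the Function directly, without any Module wrapper. The names MyNetwork, input, and label are taken from the question's snippet, and the optimizer is just an illustration:

model = MyNetwork()  # your own network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for t in range(500):
    optimizer.zero_grad()
    output = model(input)
    loss = MyLoss.apply(output, label)  # no instantiation, just .apply
    loss.backward()                     # calls your custom backward, then autograd takes over
    optimizer.step()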