My loss function is special and cannot be handled by autograd, so I want to define `backward()` in my loss function myself. But I don't know whether I can compute the gradient of the loss by hand and have it propagate back through the network automatically?
```python
import torch

class MyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, label):  # first argument must be ctx, not self
        ctx.save_for_backward(input, label)
        # my code: compute loss from input and label
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        input, label = ctx.saved_tensors
        # my code: compute grad_input by hand
        return grad_input, None  # no gradient needed for label

model = MyNetwork()  # this is my network

for t in range(500):
    output = model(input)
    loss = MyLoss.apply(output, label)  # call through .apply, not MyLoss()
    loss.backward()
```
Is the above code feasible? I am not sure whether the gradient will automatically propagate back through the rest of the network.
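For reference, here is a minimal self-contained version of the pattern I am asking about. The MSE loss body, the `nn.Linear` model, and the manual SGD step are just stand-ins for my real loss, `MyNetwork`, and training loop, not my actual code:

```python
import torch
import torch.nn as nn

class MyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, label):
        ctx.save_for_backward(input, label)
        # stand-in for the real loss: mean squared error
        return ((input - label) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        input, label = ctx.saved_tensors
        # hand-computed gradient of the stand-in loss w.r.t. input
        grad_input = grad_output * 2.0 * (input - label) / input.numel()
        return grad_input, None  # label needs no gradient

model = nn.Linear(10, 1)  # stand-in for MyNetwork
x = torch.randn(8, 10)
label = torch.randn(8, 1)

for t in range(500):
    output = model(x)
    loss = MyLoss.apply(output, label)
    loss.backward()  # grad_input flows back into the model's parameters
    with torch.no_grad():
        for p in model.parameters():
            p -= 1e-2 * p.grad
            p.grad.zero_()

# optional: verify the hand-written backward against numerical gradients
x64 = torch.randn(4, 1, dtype=torch.double, requires_grad=True)
y64 = torch.randn(4, 1, dtype=torch.double)
torch.autograd.gradcheck(MyLoss.apply, (x64, y64))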