How to check the gradients of custom implemented loss function

Hi,

I have implemented a custom loss function (both forward and backward). However, I don't know how to check whether it is implemented correctly. I tried to follow gradcheck: http://pytorch.org/docs/master/notes/extending.html

However, it is not very clear to me how to use gradcheck. Can anyone help? Thanks.

import torch
from torch.autograd import Function


class WeightedDiceLoss4Organs(Function):
    def __init__(self, *args, **kwargs):
        pass

    def forward(self, inputs, targets, save=True):
        if save:
            self.save_for_backward(inputs, targets)
        .....

    def backward(self, grad_output):
        inputs, _ = self.saved_tensors  # we need probabilities for input
        print('type(grad_output):', type(grad_output))
        ........
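For reference, recent PyTorch releases no longer accept instance-style Functions like the one above: forward and backward must be static methods that take a ctx object, and the Function is called through .apply. Below is a minimal sketch in that style, assuming an unweighted soft Dice loss over probabilities p and targets t, loss = 1 - (2 sum(p t) + s) / (sum(p) + sum(t) + s), with a hand-written backward. The class name SoftDiceLoss and the smoothing constant s are hypothetical; this is not the weighted multi-organ loss from the question.

```python
import torch
from torch.autograd import Function, gradcheck


class SoftDiceLoss(Function):
    """Hypothetical unweighted soft Dice loss with a hand-written backward.

    loss = 1 - (2 * sum(p * t) + s) / (sum(p) + sum(t) + s)
    """

    @staticmethod
    def forward(ctx, probs, targets, smooth):
        ctx.save_for_backward(probs, targets)
        ctx.smooth = smooth  # plain Python float, stored on ctx
        intersection = (probs * targets).sum()
        denom = probs.sum() + targets.sum() + smooth
        return 1 - (2 * intersection + smooth) / denom

    @staticmethod
    def backward(ctx, grad_output):
        probs, targets = ctx.saved_tensors
        smooth = ctx.smooth
        intersection = (probs * targets).sum()
        denom = probs.sum() + targets.sum() + smooth
        # d loss / d p_i = ((2I + s) - 2 * t_i * D) / D**2
        grad_probs = ((2 * intersection + smooth) - 2 * targets * denom) / denom ** 2
        # One return value per forward input: probs, targets, smooth
        return grad_output * grad_probs, None, None


# gradcheck wants double precision; only probs is checked since targets has no grad
probs = torch.rand(2, 3, 4, dtype=torch.double, requires_grad=True)
targets = (torch.rand(2, 3, 4, dtype=torch.double) > 0.5).double()
ok = gradcheck(SoftDiceLoss.apply, (probs, targets, 1.0))
print(ok)  # True when backward matches the finite-difference Jacobian
```

The backward formula comes from differentiating the ratio: with I = sum(p t) and D = sum(p) + sum(t) + s, the derivative of the loss with respect to p_i is ((2I + s) - 2 t_i D) / D^2. The non-tensor smooth argument is passed through the inputs tuple, just like save in the answer below.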

Hi

torch.autograd.gradcheck takes as input the function to call to compute the forward pass and a tuple with that function's inputs.
In your case, you want:

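import torch

# gradcheck only checks inputs with requires_grad=True, and it needs double precision to be reliable
input = torch.rand(your_input_size, dtype=torch.double, requires_grad=True)
target = torch.rand(your_target_size, dtype=torch.double)
save = True
res = torch.autograd.gradcheck(WeightedDiceLoss4Organs(), (input, target, save), raise_exception=False)
print(res)  # res should be True if the gradients are correct.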

If your gradients are not correct, you can set raise_exception to True to find out which gradient was not computed properly.
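To see what raise_exception changes, here is a sketch with a hypothetical Function BuggyCube whose backward is deliberately wrong (it returns 2x² instead of 3x²). With raise_exception=False, gradcheck just returns False; with raise_exception=True, it raises an exception (a RuntimeError subclass in the versions I know of) whose message describes which output/input pair has a mismatching Jacobian.

```python
import torch
from torch.autograd import Function, gradcheck


class BuggyCube(Function):
    """Hypothetical Function computing x**3 with a deliberately wrong backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Bug on purpose: the correct derivative is 3 * x**2
        return grad_output * 2 * x ** 2


x = torch.tensor([0.5, -1.0, 1.5, 2.0], dtype=torch.double, requires_grad=True)

# raise_exception=False: gradcheck only reports pass/fail
passed = gradcheck(BuggyCube.apply, (x,), raise_exception=False)
print(passed)

# raise_exception=True: the exception message describes the mismatch
message = ""
try:
    gradcheck(BuggyCube.apply, (x,), raise_exception=True)
except RuntimeError as err:
    message = str(err)
print(message)
```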
Moreover, keep in mind that this check is done with finite differences, so you must evaluate it at an input point around which your function is smooth; otherwise the gradient check is not valid.
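The idea behind the check can be reproduced by hand: compare the analytical gradient from backward with central differences, (f(x + eps e_i) - f(x - eps e_i)) / (2 eps). The sketch below does this for a smooth function, where the two agree, and then at the kink of relu at 0, where they disagree even though nothing is wrong with relu's backward; the function simply is not differentiable there. The helper numerical_grad is a hypothetical illustration, not gradcheck's internal code.

```python
import torch


def numerical_grad(f, x, eps=1e-6):
    """Central finite differences of a scalar-valued f at a 1-D double tensor x."""
    grad = torch.zeros_like(x)
    for i in range(x.numel()):
        step = torch.zeros_like(x)
        step[i] = eps
        grad[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


def smooth_f(x):
    # Smooth everywhere: sum_i x_i * sin(x_i)
    return (x * torch.sin(x)).sum()


# Smooth case: backward and finite differences agree
x = torch.tensor([0.3, -1.2, 2.0], dtype=torch.double, requires_grad=True)
smooth_f(x).backward()
analytic = x.grad.clone()
numeric = numerical_grad(smooth_f, x.detach())
print(torch.allclose(analytic, numeric, atol=1e-6))

# Non-smooth case: relu has a kink at 0
x0 = torch.zeros(1, dtype=torch.double, requires_grad=True)
torch.relu(x0).sum().backward()
relu_analytic = x0.grad.clone()  # autograd uses 0 as the gradient at the kink
relu_numeric = numerical_grad(lambda t: torch.relu(t).sum(), x0.detach())  # (eps - 0) / (2 eps)
print(relu_analytic, relu_numeric)
```

For a Dice-style loss this mostly matters if the loss contains clamps, thresholds, or max operations; picking random inputs away from such points avoids spurious failures.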


Thanks a lot. I am trying your suggestion.