How to check the gradients of custom implemented loss function


I implemented a custom loss function (forward and backward), but I do not know how to check whether it is implemented correctly. I tried to follow gradcheck:

However, it is not very clear how to use gradcheck. Can anyone provide some help? Thanks.

class WeightedDiceLoss4Organs(Function):
    def __init__(self, *args, **kwargs):
        pass  # ...

    def forward(self, inputs, targets, save=True):
        if save:
            self.save_for_backward(inputs, targets)
        # ... compute and return the dice loss ...

    def backward(self, grad_output):
        inputs, _ = self.saved_tensors  # we need probabilities for input
        print('type(grad_output): ', type(grad_output))
        # ... compute and return the gradient w.r.t. inputs ...


torch.autograd.gradcheck takes two arguments: the callable that computes the forward pass, and the inputs to call it with.
In your case, you want:

# gradcheck uses finite differences, so the inputs should be double
# precision, and the tensors you want gradients for need requires_grad=True
input = torch.rand(your_input_size, dtype=torch.double, requires_grad=True)
target = torch.rand(your_target_size, dtype=torch.double)
save = True
res = torch.autograd.gradcheck(WeightedDiceLoss4Organs(), (input, target, save), raise_exception=False)
print(res)  # res should be True if the gradients are correct

If your gradients are not correct, you can set raise_exception to True to know which gradient was not computed properly.
Moreover, keep in mind that this check is done by finite differences, so you must evaluate it at an input point around which your function is smooth to get a valid gradient check.
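To make this concrete, here is a minimal self-contained sketch of the whole workflow: a hypothetical soft-dice-style loss (`SoftDiceLoss`, not your actual `WeightedDiceLoss4Organs`) with a hand-written backward, verified against numerical gradients. Note it uses the static-method `Function` API with a `ctx` object; if you are on an older PyTorch, the instance-based style in your snippet works the same way with gradcheck.

```python
import torch
from torch.autograd import gradcheck

class SoftDiceLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, targets):
        # L = 1 - 2 * inter / union, a simple soft-dice loss
        eps = 1e-6
        inter = (inputs * targets).sum()
        union = inputs.sum() + targets.sum() + eps
        ctx.save_for_backward(inputs, targets)
        ctx.inter, ctx.union = inter, union
        return 1 - 2 * inter / union

    @staticmethod
    def backward(ctx, grad_output):
        inputs, targets = ctx.saved_tensors
        inter, union = ctx.inter, ctx.union
        # dL/dx_i = -2 * (t_i * union - inter) / union**2
        grad_inputs = -2 * (targets * union - inter) / union ** 2
        # one gradient per forward input; targets need no gradient
        return grad_output * grad_inputs, None

# double precision + requires_grad=True on the checked input
inputs = torch.rand(4, 5, dtype=torch.double, requires_grad=True)
targets = torch.rand(4, 5, dtype=torch.double)
print(gradcheck(SoftDiceLoss.apply, (inputs, targets)))  # prints True
```

If the analytic backward had a sign or scaling error, gradcheck would report which entry of the Jacobian disagrees with the numerical estimate (with `raise_exception=True`, the default).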


Thanks a lot. I am trying your suggestion.