Getting None gradients on custom loss function

Hi! I am trying to implement the paper: http://openaccess.thecvf.com/content_CVPR_2019/papers/Zhang_Cascaded_Generative_and_Discriminative_Learning_for_Microcalcification_Detection_in_Breast_CVPR_2019_paper.pdf from CVPR 2019.

The authors present a novel t-test-based loss function that seems fairly straightforward to implement; however, after spending countless hours on it, the gradients are still not being calculated.
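
For reference, the loss as I read it out of the paper (my own notation), with r(I) = \sum_x |f_\theta(I)(x) - I(x)| the per-patch residual and \mu_p, \sigma_p^2 / \mu_n, \sigma_n^2 the mean and variance of r over the positive/negative half of the batch, is:

L = \max(\beta - \mu_p, 0) + \mu_n + \lambda_p \sigma_p^2 + \lambda_n \sigma_n^2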

class calculate_asn_loss_t_test(nn.Module):
    """This is a customized class to calculate the loss for the ASN block of the architecture
    using the t-test loss, please note that in all the tensors with the batch size B
    the first B/2 examples are the positive patches and the next B/2 samples are negative patches"""

    def __init__(self, original_tensor, reconstructed_tensor, beta=0.8, lambda_p=1, lambda_n=0.01):
        super(calculate_asn_loss_t_test, self).__init__()
        self.original_tensor = original_tensor.cuda()
        self.reconstructed_tensor = reconstructed_tensor.cuda()
        self.residual_l = original_tensor.size(0)
        self.beta = beta
        self.lambda_p = lambda_p
        self.lambda_n = lambda_n
        self.residual_tensor = None
        self.pos_residual = None
        self.neg_residual = None
        self.pos_std = None
        self.neg_std = None


    def calculate_residual(self):
        """r(I) = Summation(|f(theta; I) - I|)
        pixel wise summation for each patch as stated in the paper"""
        try:
            self.residual_tensor = torch.abs(self.reconstructed_tensor - self.original_tensor).cuda()
            self.residual_tensor = self.residual_tensor.sum(dim=[1,2,3]).cuda()
        except Exception as e:
            print("[calculate_asn_loss_t_test][calculate_residual] Error occured : {}".format(e))


    def forward(self):
        self.calculate_residual()
        self.pos_residual = self.residual_tensor[:self.residual_l//2].float().cuda()
        self.neg_residual = self.residual_tensor[self.residual_l//2:].float().cuda()
        self.pos_mean = self.pos_residual.mean().cuda()
        self.neg_mean = self.neg_residual.mean().cuda()
        self.pos_std = self.pos_residual.std().cuda()
        self.neg_std = self.neg_residual.std().cuda()
        loss = torch.max(self.beta - self.pos_mean, torch.tensor([0.0]).cuda()) + self.neg_mean + self.lambda_p * (self.pos_std ** 2) + self.lambda_n * (self.neg_std ** 2)
        loss_ret = torch.autograd.Variable(loss, requires_grad=True)
        return loss_ret

Basically, the idea is that the image is reconstructed by a U-Net-based autoencoder backbone. The tensor batch of size N consists of N/2 positive samples followed by N/2 negative samples.

            loss = calculate_asn_loss_t_test(input_loss.float(), output_asn.float())
            loss_ = loss.forward()
            # loss = calculate_asn_loss_t_test_f(input_loss.float(), output_asn.float()).float()
            # print("OUTPUT LOSS: ", loss)
            loss_.backward()

input_loss --> the input tensor
output_asn --> the output generated by the model

Finally, I am trying to backpropagate the gradients. However, the grad attribute of the parameters is always None.

Is there something fundamentally wrong here?

Hi,

Your implementation is problematic here:
loss_ret = torch.autograd.Variable(loss, requires_grad=True)
Beyond the fact that Variables are not used anymore and you should not use them, what this line means is: “take this Tensor, discard its past, and start tracking gradients for it”.

Autograd can only work if you perform all operations with PyTorch methods and you don’t break the graph between the tensors you want gradients for and the point where you call backward.
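
To illustrate with a minimal, self-contained sketch (not from the original post; the tiny Linear layer is just a stand-in for the autoencoder):

import torch
import torch.nn as nn

lin = nn.Linear(4, 1)
x = torch.randn(8, 4)

# Broken: re-wrapping the result in a fresh requires_grad tensor discards its history,
# so backward() stops at this new leaf and lin's parameters never receive a gradient.
loss = lin(x).mean()
loss_rewrapped = torch.tensor(loss.item(), requires_grad=True)  # same effect as the Variable re-wrap
loss_rewrapped.backward()
print(lin.weight.grad)  # None

# Correct: call backward() on the loss tensor itself so the graph stays intact.
loss = lin(x).mean()
loss.backward()
print(lin.weight.grad)  # an actual gradient tensor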

Also, I would advise against saving intermediate results in self. for nn.Modules, as they should be as stateless as possible.
If you want to be able to inspect these values later, you can simply return them in most cases.

Thanks a lot for the suggestions.

I’ll make the changes and let you know.

I made the changes as you suggested. The gradients are now being calculated; however, they are all zero and the weights become NaN. This happens immediately after the first backpropagation. This suggests that something is wrong with the loss function, but I implemented it exactly as described in the paper, so I doubt the problem is non-differentiability of the function.

class calculate_asn_loss_t_test(nn.Module):
    """This is a customized class to calculate the loss for the ASN block of the architecture
    using the t-test loss, please note that in all the tensors with the batch size B
    the first B/2 examples are the positive patches and the next B/2 samples are negative patches"""

    def __init__(self):
        super(calculate_asn_loss_t_test, self).__init__()

    def forward(self, original_tensor, reconstructed_tensor, beta=0.8, lambda_p=1, lambda_n=0.01):
        residual_l = original_tensor.size(0)
        # r(I) = sum over pixels of |f(theta; I) - I|, giving one residual per patch
        residual_tensor = torch.abs(reconstructed_tensor - original_tensor).cuda()
        residual_tensor = residual_tensor.sum(dim=[1, 2, 3]).cuda()
        # first half of the batch: positive patches, second half: negative patches
        pos_residual = residual_tensor[:residual_l//2].float().cuda()
        neg_residual = residual_tensor[residual_l//2:].float().cuda()
        pos_mean = pos_residual.mean().cuda()
        neg_mean = neg_residual.mean().cuda()
        pos_std = pos_residual.std().cuda()
        neg_std = neg_residual.std().cuda()
        # hinge on the positive mean residual, plus the negative mean residual,
        # plus the variances of the two halves as regularizers
        loss = torch.max(beta - pos_mean, torch.tensor([0.0]).cuda()) + neg_mean + lambda_p * (pos_std ** 2) + lambda_n * (neg_std ** 2)
        return loss
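
This is roughly how I am calling it now (a minimal sketch of my training step; model, optimizer and train_loader are placeholders, not my actual code):

criterion = calculate_asn_loss_t_test()

for input_loss, _ in train_loader:                 # placeholder data loader
    input_loss = input_loss.float().cuda()
    output_asn = model(input_loss)                 # placeholder for the U-Net style autoencoder

    optimizer.zero_grad()
    loss_ = criterion(input_loss, output_asn.float())   # goes through __call__ -> forward
    loss_.backward()

    # quick sanity check while chasing the NaN weights
    if torch.isnan(loss_).any():
        print("NaN loss detected")

    optimizer.step()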

Could you provide a small code sample (30-40 lines) that I can run to reproduce this behavior, please?

I am really sorry, I forgot about this. There was a bug in the way I was loading the data that was causing the weights to become NaN. It had nothing to do with the loss function after all.

Thanks a lot for your concern though.
