How to implement my own gradient for a loss function and mix it with a standard loss in a joint loss


For my computer vision research I need to implement a joint loss L = L1 + L2 on predicted pixel maps. For L1 I use a standard loss that PyTorch provides, such as MSE. However, L2 involves several sparse matrix multiplications, and I think there is no autograd support for sparse matrix operations.

So if I want to implement a new class for my L2 loss and write both forward and backward myself, which class should I inherit from: torch.autograd.Function or nn.Module? And how do I implement the gradient of my own loss function?

In addition, how do I call these functions so that I get the gradient of the loss with respect to the predicted output from both L1 and L2, given that L1's gradient is computed by PyTorch's autograd and L2's gradient is implemented by me?



You should create an autograd Function that performs all your sparse operations:

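import torch
from torch import autograd

class SparseOperation(autograd.Function):
    @staticmethod
    def forward(ctx, input, target):
        # here come the sparse operations that compute `loss`; save what backward needs
        ctx.save_for_backward(input, target)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        # compute grad_input (gradient w.r.t. input) by hand, scaled by grad_output;
        # return one value per forward input (None for target, which needs no gradient)
        return grad_input, None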
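As a concrete, runnable sketch of such a Function, assume (hypothetically) that L2 is the quadratic penalty L2(p) = sum((S p)^2), where p is the flattened predicted pixel map and S is a fixed sparse COO matrix. Its gradient is 2 S^T S p, so backward needs only two sparse-dense products. The class name SparseQuadraticLoss and the choice of penalty are illustrative assumptions, not part of the original answer.

```python
import torch


class SparseQuadraticLoss(torch.autograd.Function):
    """Hypothetical example loss: L2(p) = sum((S @ p)**2) for a fixed sparse matrix S."""

    @staticmethod
    def forward(ctx, pred, S):
        p = pred.reshape(-1, 1)            # flatten the pixel map to a column vector
        Sp = torch.sparse.mm(S, p)         # sparse @ dense -> dense of shape (n, 1)
        ctx.save_for_backward(S, Sp)       # both are reused in backward
        ctx.pred_shape = pred.shape
        return (Sp ** 2).sum()             # scalar loss

    @staticmethod
    def backward(ctx, grad_output):
        S, Sp = ctx.saved_tensors
        # dL2/dp = 2 * S^T (S p), computed with another sparse-dense product
        grad_p = 2.0 * torch.sparse.mm(S.t(), Sp)
        # chain rule: scale by the incoming gradient of the (scalar) loss
        grad_pred = grad_output * grad_p.reshape(ctx.pred_shape)
        # one return value per forward input; the fixed matrix S gets no gradient
        return grad_pred, None
```

A quick way to check a hand-written backward is torch.autograd.gradcheck on double-precision inputs, for example gradcheck(lambda x: SparseQuadraticLoss.apply(x, S), (x,)), which compares your gradient against finite differences.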

Then wrap it in a loss module that subclasses nn.Module:

from torch import nn

class AwesomeLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        # custom autograd Functions are called through .apply(), not instantiated
        loss = SparseOperation.apply(input, target)
        return loss

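Finally, you can sum L = L1 + L2 directly and call L.backward(). Autograd adds the gradient it computes for L1 to the gradient your backward returns for L2, so if your hand-written gradient is correct, it should work.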
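Here is a minimal end-to-end sketch of that recipe, assuming a hypothetical L2(p) = sum((S p)^2) for a fixed sparse matrix S, and MSE as L1. SparseQuadraticLoss, SparseLoss and joint_loss_grad are illustrative names, not from the original answer. A single backward() on L = L1 + L2 leaves in pred.grad the sum of the MSE gradient from autograd and the hand-written L2 gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseQuadraticLoss(torch.autograd.Function):
    # hypothetical L2(p) = sum((S @ p)**2) with a hand-written backward
    @staticmethod
    def forward(ctx, pred, S):
        Sp = torch.sparse.mm(S, pred.reshape(-1, 1))
        ctx.save_for_backward(S, Sp)
        ctx.pred_shape = pred.shape
        return (Sp ** 2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        S, Sp = ctx.saved_tensors
        grad_pred = 2.0 * torch.sparse.mm(S.t(), Sp).reshape(ctx.pred_shape)
        return grad_output * grad_pred, None  # no gradient for S


class SparseLoss(nn.Module):
    def __init__(self, S):
        super().__init__()
        self.S = S  # fixed sparse matrix; a plain attribute, so .to(device) does not move it

    def forward(self, pred):
        return SparseQuadraticLoss.apply(pred, self.S)


def joint_loss_grad(pred, target, S):
    """Return dL/dpred for L = MSE(pred, target) + L2(pred)."""
    pred = pred.detach().requires_grad_(True)
    L1 = F.mse_loss(pred, target)  # gradient comes from PyTorch's autograd
    L2 = SparseLoss(S)(pred)       # gradient comes from our backward()
    L = L1 + L2
    L.backward()                   # autograd sums both contributions into pred.grad
    return pred.grad
```

In a training loop you would call L.backward() followed by optimizer.step() as usual; the custom Function behaves like any other differentiable operation in the graph.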