How to compute autograd for a custom PyTorch function?

I was building a custom loss function.

I can't extend nn.Module because, as ptrblck pointed out, I can't use an nn.Module when there are NumPy operations inside it.

I couldn't stick to torch functions alone because NumPy has functions I wanted to use, such as np.maximum, that are still a work in progress in PyTorch. After cleaning up, my final loss function is below:

import numpy as np
import torch

class TripletLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v1, v2, margin=0.25):
        margin = torch.tensor(margin)
        ctx.save_for_backward(v1, v2, margin)
        # Similarity matrix between the two batches, computed in NumPy.
        scores = np.dot(v1.detach().numpy(), v2.detach().numpy().T)
        batch_size = len(scores)
        positive = np.diag(scores)  # the positive ones (duplicates)
        negative_without_positive = scores - 2.0 * np.identity(batch_size)
        closest_negative = negative_without_positive.max(axis=1)
        negative_zero_on_duplicate = scores * (1.0 - np.eye(batch_size))
        mean_negative = np.sum(negative_zero_on_duplicate, axis=1) / (batch_size - 1)
        triplet_loss1 = np.maximum(0.0, margin - positive + closest_negative)
        triplet_loss2 = np.maximum(0.0, margin - positive + mean_negative)
        triplet_loss = torch.mean(triplet_loss1 + triplet_loss2).requires_grad_()
        return triplet_loss

    @staticmethod
    def backward(ctx, grad_output):
        v1, v2, margin = ctx.saved_tensors
        # Placeholder: I don't know what to return here.
        grad_v1, grad_v2, grad_margin = v1, v2, margin
        return grad_v1, grad_v2, grad_margin

I can't get past this. How do I write the backward when the function takes two or more arguments and uses NumPy functions inside?
Calling .backward() doesn't work. I tried a bunch of tactics described in the examples, such as grad_input[input[0]<0] = 0.
Some gave RuntimeError: too many indices for tensor of dimension 0, and other variations gave RuntimeError: function TripletLossBackward returned an incorrect number of gradients (expected 2, got 1).

Could I get a snippet showing how to solve this kind of problem for future cases, i.e. how to use NumPy functions inside a custom autograd Function while still tracking gradients?

Also, a feature request: it would be great if PyTorch provided a NumPy-equivalent library containing all the NumPy functions, like JAX and other frameworks do.

Hi,

"NumPy has functions I wanted to use, such as np.maximum, that are still a work in progress in PyTorch"

Not sure what you mean by that. You can use torch.max() to get the element-wise maximum, or the threshold function if you just want a threshold, which seems to be the case in your code.
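
For example, a quick sketch of the equivalents (the tensor names here are made up):

import torch

a = torch.randn(4)
b = torch.randn(4)

torch.max(a, b)          # element-wise maximum, like np.maximum(a, b)
torch.clamp(a, min=0.0)  # like np.maximum(a, 0.0), i.e. a threshold at zero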

I would recommend you read the doc that explains how to write a custom Function here: Extending PyTorch — PyTorch 2.1 documentation
That will solve your issues with the wrong number and type of returns from the backward.
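
The key rule is that backward must return exactly one value per argument of forward, in the same order, with None for arguments that don't need a gradient. A minimal sketch of the pattern (the NumpyMaximum name and the example op are mine, not from the doc):

import numpy as np
import torch

class NumpyMaximum(torch.autograd.Function):
    # np.maximum(x, threshold) computed in NumPy, with a hand-written backward.

    @staticmethod
    def forward(ctx, x, threshold):
        ctx.save_for_backward(x)
        ctx.threshold = threshold  # non-tensor state can be stashed on ctx
        # Leave autograd, compute in NumPy, then come back to torch.
        out = np.maximum(x.detach().cpu().numpy(), threshold)
        return torch.from_numpy(out).to(x.device)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad_x = grad_output.clone()
        grad_x[x < ctx.threshold] = 0.0  # max only passes the gradient where x wins
        # forward took two arguments (x, threshold), so backward must return
        # two values; threshold is a plain float, so its gradient is None.
        return grad_x, None

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
NumpyMaximum.apply(x, 0.0).sum().backward()  # x.grad is now populated

Your TripletLoss.forward takes three arguments (v1, v2, margin), so its backward has to return three values, e.g. (grad_v1, grad_v2, None), and the tensor gradients have to actually be the derivatives of the loss with respect to v1 and v2, not the saved inputs themselves.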

"it would be great if PyTorch provided a NumPy-equivalent library containing all the NumPy functions, like JAX and other frameworks do"

This is already work in progress. You can follow the work in this issue: https://github.com/pytorch/pytorch/issues/38349


Final solution: HardTripletLoss works pretty well with the NumPy-equivalent functions in PyTorch, after upgrading to the unstable version (1.8…), which has all the functions I needed. I stuck with nn.Module; I found custom autograd Functions too hard and won't use them, at least for now.

import torch
import torch.nn as nn

class HardTripletLoss(nn.Module):
    def __init__(self):
        super(HardTripletLoss, self).__init__()

    def forward(self, v1, v2, margin):
        # Similarity matrix between the two batches of embeddings.
        scores = v1 @ v2.T
        batch_size = len(scores)
        eye = torch.eye(batch_size, device=scores.device)
        positive = torch.diag(scores)  # diagonal holds the positive pairs
        # Push the diagonal far down so it can never win the row-wise max.
        negative_without_positive = scores - 2.0 * eye
        closest_negative = negative_without_positive.max(dim=1)[0]
        # Zero the diagonal, then average the remaining negatives per row.
        negative_zero_on_duplicate = scores * (1.0 - eye)
        mean_negative = torch.sum(negative_zero_on_duplicate, dim=1) / (batch_size - 1)
        zero = torch.tensor(0.0, device=scores.device)
        triplet_loss1 = torch.maximum(margin - positive + mean_negative, zero)
        triplet_loss2 = torch.maximum(margin - positive + closest_negative, zero)
        return torch.mean(triplet_loss1 + triplet_loss2)
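
For reference, a minimal usage sketch (the batch size, embedding size, and margin below are made up):

loss_fn = HardTripletLoss()
v1 = torch.randn(8, 16, requires_grad=True)
v2 = torch.randn(8, 16, requires_grad=True)
loss = loss_fn(v1, v2, margin=0.25)
loss.backward()  # gradients flow end to end, since every op is a torch op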

Thanks 🙂