I am building a custom loss function. I can't write it as an `nn.Module` extension because, as ptrblck said, I can't use `nn.Module` when it contains NumPy operations. I also couldn't use `torch.` functions, because `numpy` has more functions I wanted to try, such as `np.maximum`, which are still a work in progress in PyTorch. After cleaning up, my final loss function is below:

```
import numpy as np
import torch


class TripletLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v1, v2, margin=0.25):
        margin = torch.tensor(margin)
        ctx.save_for_backward(v1, v2, margin)
        scores = np.dot(v1.detach().numpy(), v2.detach().numpy().T)
        batch_size = len(scores)
        positive = np.diag(scores)  # the positive ones (duplicates)
        negative_without_positive = scores - 2.0 * np.identity(batch_size)
        closest_negative = negative_without_positive.max(axis=1)
        negative_zero_on_duplicate = scores * (1.0 - np.eye(batch_size))
        mean_negative = np.sum(negative_zero_on_duplicate, axis=1) / (batch_size - 1)
        triplet_loss1 = np.maximum(0.0, margin - positive + closest_negative)
        triplet_loss2 = np.maximum(0.0, margin - positive + mean_negative)
        triplet_loss = torch.mean(triplet_loss1 + triplet_loss2).requires_grad_()
        return triplet_loss

    @staticmethod
    def backward(ctx, grad_output):
        v1, v2, margin = ctx.saved_tensors
        grad_v1, grad_v2, grad_margin = v1, v2, margin
        return grad_v1, grad_v2, grad_margin
```

I can't get past this. How can I write it when the function takes two or more arguments and uses NumPy functions?

It doesn't work when I call `.backward`. I tried a bunch of the tactics described in the example, such as `grad_input[input[0]<0] = 0`. Some gave the error `RuntimeError: too many indices for tensor of dimension 0`, and after changing the code I sometimes got `RuntimeError: function TripletLossBackward returned an incorrect number of gradients (expected 2, got 1)`.

Could I get an example snippet for solving this kind of problem in future cases, when using NumPy functions and tracking gradients with custom autograd functions?
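One way this could look is sketched below: a `torch.autograd.Function` that runs the forward pass in NumPy and supplies the gradient by hand. Autograd cannot see through NumPy, so `backward` has to return the analytic gradient, one value per argument passed to `forward` (with `None` for the non-tensor `margin`). The class name `NumpyTripletLoss`, the attribute `ctx.grad_scores`, and the gradient derivation are my own illustration under the assumption that the inputs are CPU tensors with a batch size above 1, so treat it as a sketch to verify rather than a reference implementation.

```python
import numpy as np
import torch


class NumpyTripletLoss(torch.autograd.Function):
    """Hypothetical sketch: NumPy forward pass with a hand-written backward."""

    @staticmethod
    def forward(ctx, v1, v2, margin=0.25):
        a = v1.detach().cpu().numpy()
        b = v2.detach().cpu().numpy()
        n = a.shape[0]
        rows = np.arange(n)
        eye = np.eye(n, dtype=a.dtype)

        scores = a @ b.T
        positive = np.diag(scores)
        masked = scores - 2.0 * eye              # push the diagonal down
        closest_idx = masked.argmax(axis=1)      # remembered for the gradient
        closest_negative = masked[rows, closest_idx]
        mean_negative = (scores * (1.0 - eye)).sum(axis=1) / (n - 1)
        loss1 = np.maximum(0.0, margin - positive + closest_negative)
        loss2 = np.maximum(0.0, margin - positive + mean_negative)

        # Analytic d(loss)/d(scores): each hinge only contributes where it is active.
        active1 = (loss1 > 0).astype(a.dtype)
        active2 = (loss2 > 0).astype(a.dtype)
        grad_scores = np.zeros_like(scores)
        grad_scores[rows, closest_idx] += active1            # +1 on the closest negative
        grad_scores[rows, rows] -= active1 + active2         # -1 on the positive, per hinge
        grad_scores += active2[:, None] * (1.0 - eye) / (n - 1)  # mean over negatives
        grad_scores /= n                                     # final mean over the batch

        ctx.save_for_backward(v1, v2)
        ctx.grad_scores = torch.from_numpy(grad_scores)
        return v1.new_tensor(float(np.mean(loss1 + loss2)))

    @staticmethod
    def backward(ctx, grad_output):
        v1, v2 = ctx.saved_tensors
        g = ctx.grad_scores.to(v1) * grad_output   # chain rule with the incoming gradient
        grad_v1 = g @ v2                           # scores = v1 @ v2.T
        grad_v2 = g.t() @ v1
        # One return value per forward argument; margin is not a tensor, so None.
        return grad_v1, grad_v2, None


v1 = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
v2 = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
loss = NumpyTripletLoss.apply(v1, v2, 0.25)
loss.backward()
print(v1.grad.shape, v2.grad.shape)
```

The loss is a scalar, so `grad_output` is a 0-dimensional tensor; it can be multiplied into the gradient but not indexed with a mask. Comparing the result against a pure `torch` version of the same loss, or running `torch.autograd.gradcheck` on double-precision inputs, is a reasonable way to check a hand-derived gradient like this one.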

Also, a feature request: it would be great if PyTorch provided a `numpy`-equivalent library containing all NumPy functions, as JAX and other frameworks do.
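For comparison, the NumPy calls used in the loss above seem to map onto existing `torch` operations (`np.maximum(0.0, x)` to `torch.clamp(x, min=0.0)`, `np.diag` to `torch.diag`, `np.eye` to `torch.eye`), in which case autograd could derive the gradient with no custom `backward` at all. A minimal sketch under that assumption; the function name `triplet_loss` is my own choice:

```python
import torch


def triplet_loss(v1, v2, margin=0.25):
    """Same computation as the NumPy version, written with torch ops only."""
    scores = v1 @ v2.t()
    n = scores.shape[0]
    eye = torch.eye(n, dtype=scores.dtype, device=scores.device)
    positive = torch.diag(scores)
    # Subtract 2 on the diagonal so the positive is not picked as a negative.
    closest_negative = (scores - 2.0 * eye).max(dim=1).values
    mean_negative = (scores * (1.0 - eye)).sum(dim=1) / (n - 1)
    loss1 = torch.clamp(margin - positive + closest_negative, min=0.0)
    loss2 = torch.clamp(margin - positive + mean_negative, min=0.0)
    return (loss1 + loss2).mean()


a = torch.randn(4, 3, requires_grad=True)
b = torch.randn(4, 3, requires_grad=True)
triplet_loss(a, b).backward()  # gradients land in a.grad and b.grad
```

Because every step stays a tensor operation, this version can also run on the GPU and can be wrapped in an `nn.Module` if that is more convenient.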