I was building a custom loss function. I can't use an nn.Module extension because, as ptrblck explained, nn.Module won't track gradients when NumPy operations are used inside it. I also couldn't stick to torch functions alone, because NumPy offers functions I wanted to try, such as np.maximum, some of which are still work-in-progress in PyTorch. After cleaning up, my final loss function is below:
```python
class TripletLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, v1, v2, margin=0.25):
        margin = torch.tensor(margin)
        ctx.save_for_backward(v1, v2, margin)
        scores = np.dot(v1.detach().numpy(), v2.detach().numpy().T)
        batch_size = len(scores)
        positive = np.diag(scores)  # the positive ones (duplicates)
        negative_without_positive = scores - 2.0 * np.identity(batch_size)
        closest_negative = negative_without_positive.max(axis=1)
        negative_zero_on_duplicate = scores * (1.0 - np.eye(batch_size))
        mean_negative = np.sum(negative_zero_on_duplicate, axis=1) / (batch_size - 1)
        triplet_loss1 = np.maximum(0.0, margin - positive + closest_negative)
        triplet_loss2 = np.maximum(0.0, margin - positive + mean_negative)
        triplet_loss = torch.mean(triplet_loss1 + triplet_loss2).requires_grad_()
        return triplet_loss

    @staticmethod
    def backward(ctx, grad_output):
        v1, v2, margin = ctx.saved_tensors
        grad_v1, grad_v2, grad_margin = v1, v2, margin
        return grad_v1, grad_v2, grad_margin
```
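From the docs, I understand the basic pattern for mixing NumPy into a custom autograd Function looks something like this (a toy element-wise example I wrote to check my understanding, not my actual loss; `NpReLU` is just a name I made up):

```python
import numpy as np
import torch


class NpReLU(torch.autograd.Function):
    """Toy example: ReLU computed with NumPy in forward, manual backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        out = np.maximum(0.0, x.detach().numpy())  # NumPy op in forward
        return torch.from_numpy(out)               # back to a tensor

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[x < 0] = 0          # gradient of ReLU w.r.t. x
        return grad_input              # one gradient per forward input


x = torch.tensor([-1.0, 2.0], requires_grad=True)
y = NpReLU.apply(x)
y.sum().backward()
print(x.grad)  # tensor([0., 1.])
```

This simple case works for me; my trouble starts when forward takes several arguments.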
I can't get past this. How do I do it when forward takes two or more arguments and uses NumPy functions inside? Calling .backward doesn't work. I tried a bunch of tactics described in the examples, such as grad_input[input[0] < 0] = 0; some gave RuntimeError: too many indices for tensor of dimension 0, and other variations gave RuntimeError: function TripletLossBackward returned an incorrect number of gradients (expected 2, got 1). Could I get a snippet example showing how to solve this kind of problem for future cases, where NumPy functions are used inside a custom autograd Function and gradients are still tracked?
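What I believe (but haven't verified against my full loss) is that backward must return exactly one value per forward argument, with None for non-differentiable ones like margin. A minimal two-tensor sketch of that rule, with a made-up `DotScore` function computing sum(v1 * v2) + margin in NumPy:

```python
import numpy as np
import torch


class DotScore(torch.autograd.Function):
    """Toy two-argument example: score = sum(v1 * v2) + margin, via NumPy."""

    @staticmethod
    def forward(ctx, v1, v2, margin):
        ctx.save_for_backward(v1, v2)
        score = np.sum(v1.detach().numpy() * v2.detach().numpy())
        return torch.tensor(score + margin)

    @staticmethod
    def backward(ctx, grad_output):
        v1, v2 = ctx.saved_tensors
        # One return value per forward argument; None for the float margin.
        return grad_output * v2, grad_output * v1, None


v1 = torch.tensor([1.0, 2.0], requires_grad=True)
v2 = torch.tensor([3.0, 4.0], requires_grad=True)
out = DotScore.apply(v1, v2, 0.25)
out.backward()
print(v1.grad)  # tensor([3., 4.])
print(v2.grad)  # tensor([1., 2.])
```

Is this the right way to generalize it, or am I misunderstanding how the gradient count is matched to the inputs?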
Also, a feature request: it would be great if PyTorch provided a NumPy-equivalent library containing all the NumPy functions, the way JAX and other frameworks do.