Custom loss function: what's legal in the forward pass?

I’m new to PyTorch and I’m having trouble properly implementing a custom loss function that is compatible with the backward() pass. For some trivial experiments it’s fine, but for others I get an “element 0 of tensors does not require grad and does not have a grad_fn” error.
I did check requires_grad on the arguments for the following examples: x is True, y is False, p is False.

class cu_loss(torch.nn.Module):
    def __init__(self):
        super(cu_loss, self).__init__()

    #1
    def forward(self, x, y):
        da = y - torch.argmax(x, dim=1)
        result = {…}
        return result

    #2
    def forward(self, x, y, p):
        diff = torch.sum(y - p)
        return diff

The #1 forward fails simply after the addition of torch.argmax.
For the #2 forward I have no idea; am I not allowed to simply subtract tensors with the minus sign? This seems to be a common problem.
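
For context, here is a minimal, self-contained sketch (shapes and values made up) of how the #1 case fails:

import torch

x = torch.randn(4, 3, requires_grad=True)   # e.g. model logits
y = torch.tensor([2, 0, 1, 1])              # integer targets

# torch.argmax returns an integer tensor that is detached from the graph
da = y - torch.argmax(x, dim=1)
loss = da.float().abs().sum()
print(loss.requires_grad)                   # False
# loss.backward()  # would raise: "element 0 of tensors does not require grad and does not have a grad_fn"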

I read in similar topics that I can’t use functions outside of PyTorch, because some data can escape the internal state of the graph? Is there a resource with general pointers on how to do this properly?

After further research I became aware that only FloatTensors can become part of the gradient. After a type conversion I can now perform any operations in forward.
Even though I jumped this hurdle, I welcome any further insights.

You can use any PyTorch operations in the forward pass that have a valid backward function.
While the majority of operations meet this condition, some do not: e.g. torch.argmax will detach the output from the computation graph, as this operation is not differentiable.

How do I determine at a glance whether a function is differentiable? Is there an argmax analog that could be employed? Since I figured out that a FloatTensor is necessary as the result of forward and started experimenting, different experiments produce different anomalies and I’m yet to learn the PyTorch ways.

You can pass an input tensor with requires_grad=True to a particular operation and check if the output tensor has a valid .grad_fn, as seen here:


x = torch.randn(10, 10, requires_grad=True)
out = torch.argmax(x, dim=1)
print(out.grad_fn)
> None

val, idx = torch.max(x, dim=1)
print(val.grad_fn)
> <MaxBackward0 object at ...>
print(idx.grad_fn)
> None

Also, it’s often helpful to “draw” the function and think about what the gradient would look like.
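
For instance (a toy illustration, not from the original discussion), sweeping a two-element input shows why argmax has no useful gradient: the output is piecewise constant, so its derivative is zero almost everywhere and undefined at the jumps.

import torch

t = torch.linspace(0, 1, 11)
inputs = torch.stack([t, 1 - t], dim=1)   # rows move smoothly from [0, 1] to [1, 0]
print(torch.argmax(inputs, dim=1))        # step function: a run of 1s followed by 0s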

Yes, softmax (sometimes also called softargmax) is the smooth version of argmax, which provides valid gradients.
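
To make that concrete, here is one possible sketch (the tensors and the “expected class index” trick are just an illustration, not the only way to use softmax) of a softmax-based, differentiable stand-in for argmax:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, requires_grad=True)   # logits
y = torch.tensor([2, 0, 1, 1]).float()      # targets as class indices

probs = F.softmax(x, dim=1)                          # stays in the graph
classes = torch.arange(x.size(1), dtype=probs.dtype)
soft_argmax = (probs * classes).sum(dim=1)           # differentiable "expected index"
loss = (soft_argmax - y).abs().mean()

print(loss.grad_fn)   # a valid backward function, e.g. <MeanBackward0 ...>
loss.backward()       # works, x.grad is populated

Whether treating class indices as ordinal values makes sense depends on the problem; the point here is only that the result keeps a grad_fn.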

Thank you for the tips. Plotting the values certainly provides some insight. Enabling custom losses already allowed my model to reach ~20% accuracy, over 6% more than what the vanilla losses offered.
Cheers!

Hello @ptrblck (and Aoo)!

As a minor technical clarification, softmax() is a smooth version of the one_hot() encoding of the result of argmax().
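
A quick way to see this (the temperature value is chosen arbitrarily for illustration):

import torch
import torch.nn.functional as F

x = torch.tensor([[2.0, 0.5, -1.0]])

hard = F.one_hot(torch.argmax(x, dim=1), num_classes=3).float()
print(hard)                        # tensor([[1., 0., 0.]])
print(F.softmax(x, dim=1))         # a soft distribution over the same classes
print(F.softmax(x / 0.1, dim=1))   # lower temperature -> closer to the one-hot vector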

Best.

K. Frank
