# Help with implementing a custom loss function

Hello, I am trying to implement the loss function from Section 2.1 of "Right for the Right Reasons: Training Differentiable Models by Constraining their Explanations" (Ross et al., 2017).
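
For readers without the paper at hand, the objective in that section has three terms. The form below is written from memory of the paper, so please check the symbols and indices against the original. Here $N$ is the number of samples, $K$ the number of classes, $D$ the number of input dimensions, and $\lambda_1$ is assumed to correspond to `l_one` in the code in this thread:

$$
L(\theta, X, y, A) =
\underbrace{\sum_{n=1}^{N}\sum_{k=1}^{K} -y_{nk}\,\log \hat{y}_{nk}}_{\text{right answers}}
+ \underbrace{\lambda_1 \sum_{n=1}^{N}\sum_{d=1}^{D}\left(A_{nd}\,\frac{\partial}{\partial x_{nd}}\sum_{k=1}^{K}\log \hat{y}_{nk}\right)^{2}}_{\text{right reasons}}
+ \underbrace{\lambda_2 \sum_{i}\theta_i^{2}}_{\text{regularization}}
$$

Note that the square sits inside the sums over $n$ and $d$, so the mask and the input gradient are multiplied elementwise, squared, and only then summed.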

The first and third terms are the cross-entropy loss and L2 regularization, respectively, and are already implemented in PyTorch. The matrix A is a binary mask with dimensions (number of samples, W, H, number of color channels). The new loss can be called like this:

```python
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01)  # L2 is weight decay

for index, batch in enumerate(batches):
    x, target = batch
    optimizer.zero_grad()  # clear gradients left over from the previous step
    logits = model(x)
    loss = explanation_loss(logits, target, l_one, A[index])  # A[index]: mask for this batch
    loss.backward()
    optimizer.step()
```

I am confused about how to implement the second term (the "right reasons" term).

Specifically:

1. If my PyTorch model outputs logits, I must first convert them to probabilities with softmax and then into a one-hot encoding to take the log, right?

2. How do I implement the input gradient here (the partial derivative w.r.t. the input x)?

Below is my pseudo-code that I want to make functional. Here is the original code that I'm having trouble understanding (in the function called "objective"). Thank you!

```python
def explanation_loss(logits, target, l_one, mask_matrix):
    ce = nn.CrossEntropyLoss()(logits, target)
    if l_one == 0:
        return ce
    pred = F.softmax(logits, dim=1)
    label_one_hot = torch.nn.functional.one_hot(target, num_classes).float().to(device)
    label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0)  # prevent log(0)

    right_reasons = l_one * (torch.sum(mask_matrix * torch.log(label_one_hot), dim=1)).mean()
    return ce + right_reasons
```
1. You hardly ever need one-hot labels, and I don't think you need them here. Also, I would use `torch.nn.functional.cross_entropy` instead of re-instantiating the module over and over. You do need predictions, but these are per-class encoded anyway. You would want to use `log_softmax` rather than softmax + log.
2. For the loss involving a gradient w.r.t. the inputs, you'd need to call `requires_grad_()` on the inputs and then use `gr_input, = torch.autograd.grad(sum_log_hat_y, input)` to get the gradient of your loss computation.
3. If you believe your input gradients to be reasonably independent of the other inputs in the batch (this can be problematic for BatchNorm but is fine for most other things), you can sum over the batch before taking the gradient. This makes things much faster because you only need one gradient call.
4. I think you are missing a `**2` before summing.

Best regards

Thomas
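
Putting the four points above together, a minimal sketch might look like the following. It is an illustration under assumptions, not the paper's reference implementation: the signature, the toy model, and the shapes are made up, and the mask is assumed to have exactly the same shape as `x` (so an `(N, W, H, C)` mask would need permuting for `(N, C, H, W)` inputs). It also passes `create_graph=True` so that the penalty itself can be backpropagated to the parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def explanation_loss(model, x, target, mask, l_one=0.001):
    """Cross-entropy plus a masked input-gradient penalty (hypothetical signature)."""
    # Work on a fresh leaf tensor so gradients w.r.t. the input are tracked.
    x = x.detach().clone().requires_grad_(True)
    logits = model(x)
    ce = F.cross_entropy(logits, target)
    if l_one == 0:
        return ce

    # Sum log-probabilities over classes AND over the batch. One grad call then
    # yields per-sample input gradients, assuming samples do not interact
    # (e.g. no BatchNorm in training mode).
    sum_log_hat_y = F.log_softmax(logits, dim=1).sum()
    (input_grad,) = torch.autograd.grad(sum_log_hat_y, x, create_graph=True)

    # Elementwise mask, then square, then sum.
    right_reasons = l_one * ((mask * input_grad) ** 2).sum()
    return ce + right_reasons


# Toy usage with made-up shapes: 6 images, 3 channels, 4x4 pixels, 5 classes.
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 8), nn.Tanh(), nn.Linear(8, 5))
x = torch.randn(6, 3, 4, 4)
target = torch.randint(0, 5, (6,))
mask = (torch.rand_like(x) > 0.5).float()  # random binary mask, same shape as x

loss = explanation_loss(model, x, target, mask, l_one=0.1)
loss.backward()  # populates .grad on the model parameters
```

The L2 term is left to the optimizer's `weight_decay`, as in the original training loop.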

Hello @tom, thank you for the guidance. Now an error is thrown about trying to backward through the graph a second time when I call `loss.backward()`. But if I am correct, I just need the gradient values when computing `input_grad`, and the first graph can be dropped. Do you have further advice on how to solve this without specifying `retain_graph=True`, so that I avoid an OOM error?

Below is my current pseudo-code:

```python
def explanation_loss(model, x, y, l_one):
    logits = model(x)
    ce = F.cross_entropy(logits, target)

    if l_one==0:
        return ce

    sum_log_hat_y = torch.sum(F.log_softmax(logits, dim=1)).mean()
    # same dims as x

    right_reasons = l_one * (dot_prod) ** 2

# Mask matrix and loss are calculated in a single pass
x, target = batch

loss, A_batch = explanation_loss(model, x, target, 0.01)

loss.backward()  # Runtime Error
optimizer.step()
```

In the `torch.autograd.grad` call you would need both `retain_graph=True` and `create_graph=True`.
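
To see why `create_graph` matters, here is a small toy sketch (the tensors `w` and `x` are stand-ins for model parameters and an input, not anything from the thread). Without `create_graph=True` the returned gradient is a plain constant, so a penalty built from it cannot influence the parameters. With it, the gradient is part of the graph. When `create_graph=True` is passed, `retain_graph` defaults to the same value, so the shared graph also stays alive for the later `backward()` call and is released afterwards as usual.

```python
import torch

w = torch.tensor([1.0, -2.0], requires_grad=True)  # stands in for model parameters
x = torch.tensor([0.5, 3.0], requires_grad=True)   # stands in for the input


def penalty(create_graph):
    out = (w * x).pow(2).sum()  # toy "model output"
    (g,) = torch.autograd.grad(out, x, create_graph=create_graph)
    return g.pow(2).sum(), g    # squared input gradient, and the gradient itself


pen, g = penalty(create_graph=False)
print(g.requires_grad)  # False: g is detached, the penalty cannot reach w

pen, g = penalty(create_graph=True)
print(g.requires_grad)  # True: g is differentiable w.r.t. w
pen.backward()          # second-order backward through the input gradient
print(w.grad)           # analytically 16 * w**3 * x**2
```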


The loss works as well as plain cross-entropy when `l_one = 0`.

However, it does not decrease when I pass an all-zero mask matrix. It should behave the same as above because the custom term is multiplied by zero, and my assertion statements pass, which confirms this. @tom, does `autograd.grad` influence how the optimizer updates the model parameters, or is this a potential bug in autograd?

```python
def explanation_loss(logits, x, target, l_one=0.001):

    ce = F.cross_entropy(logits, target)
    if l_one == 0:
        return ce

    sum_log_hat_y = torch.sum(F.log_softmax(logits, dim=1)).mean()
    # same dims as x

    assert dot_prod == 0, dot_prod.item()

    right_reasons = l_one * ((dot_prod) ** 2)
    assert right_reasons == 0, right_reasons.item()  # These assertions pass

    return ce + right_reasons
```

I think the `torch.dot` is wrong. My understanding was that it should be an elementwise `*` and then

```python
right_reasons = l_one * ((dot_prod) ** 2).sum()
```

But this doesn't explain why you would have trouble with the `right_reasons` term if the mask matrix is all zeros…
Do you have a toy example to demonstrate this, maybe from the paper?
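
One possible toy check for this (not from the paper; the model, shapes, and the `param_grads` helper are made up for illustration) is to compare the parameter gradients of plain cross-entropy with those of cross-entropy plus the penalty under an all-zero mask. In plain PyTorch the two should agree, since the masked term and its derivative are both zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 16), nn.Tanh(), nn.Linear(16, 3))
x = torch.randn(8, 10, requires_grad=True)
target = torch.randint(0, 3, (8,))
mask = torch.zeros_like(x)  # all-zero mask: the penalty should vanish


def param_grads(loss):
    # Hypothetical helper: gradients of a scalar loss w.r.t. all model parameters.
    return torch.autograd.grad(loss, list(model.parameters()))


# Gradients of plain cross-entropy.
grads_ce = param_grads(F.cross_entropy(model(x), target))

# Gradients of cross-entropy plus the masked input-gradient penalty.
logits = model(x)
ce = F.cross_entropy(logits, target)
(input_grad,) = torch.autograd.grad(
    F.log_softmax(logits, dim=1).sum(), x, create_graph=True
)
penalty = 0.001 * ((mask * input_grad) ** 2).sum()
grads_full = param_grads(ce + penalty)

same = all(torch.allclose(a, b) for a, b in zip(grads_ce, grads_full))
print(penalty.item(), same)
```

If the gradients match here but training still stalls, the difference is more likely in the surrounding training loop than in `autograd.grad` itself.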

@tom Oops, you're right. I was following the original implementation on a perceptron too closely.

It seems that the model updates correctly in vanilla PyTorch, but it doesn't in PyTorch Lightning. If you work with the latter, I made a notebook for you to reproduce the error; otherwise I will post this in the PL community. Thanks again!