Simple custom loss with exception on backward

Hi,

I am writing a custom loss function that minimizes values taken from an intermediate matrix. Simply put, the model returns 3 values, and the index of the highest value is the index of the value that should be used from the intermediate matrix (`mm`), so in short, `loss = sum([mm[argmax(pred)] for pred in output])`. I wrote such a custom function, but the backward step raises an exception.

MWE:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

def my_loss(inp, target):
    loss = F.softmax(inp, dim=1)
    loss = torch.argmax(loss, dim=1)                   # integer index of the max value per row
    loss = torch.gather(target, 1, loss.view(-1, 1))   # gather the corresponding values from target
    loss = loss.sum()
    return loss

# values used to calculate loss
mm = torch.tensor([
    [0.3, 0.8, 1.3],
    [0.1, 0.2, 0.4]
])

# input
inp = torch.rand((2, 5))  # 2 samples x 5 features

# model
model = nn.Linear(5, 3)

# returns 3 values where the argmax is the index of the value in mm that should be used to calculate the loss
pred = model(inp)

# calculate loss
loss = my_loss(pred, mm)

loss.backward()
```

Error:

```
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```

The issue is probably with the `gather` call, since it pulls its values from the target, which is not connected to the model's output as far as the gradient calculation is concerned. Any idea how to solve it?

Your problem is that `argmax()` returns integer indices, so it isn’t differentiable and
using it “breaks the computation graph.”
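
You can check this directly: the result of `argmax()` is an integer tensor with no `grad_fn`, so autograd has nothing to backpropagate through. A small standalone illustration (not your code):

```
import torch

x = torch.randn(2, 3, requires_grad=True)
idx = torch.argmax(x, dim=1)

print(idx.dtype)          # torch.int64 -- integer indices, not floats
print(idx.requires_grad)  # False -- autograd stops here
print(idx.grad_fn)        # None -- no backward path to x
```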

The following may or may not make sense for your specific use case:

`softmax()`, conceptually, might be better named “softargmax()” (or even
“soft-one-hot-encoded-argmax()”). You could consider using:

```
loss = (F.softmax (inp, dim = 1) * target).sum()
```

as a differentiable “soft” version of your loss function. You could introduce a
parameter, `loss = (F.softmax (alpha * inp, dim = 1) * target).sum()`,
where larger values of `alpha` sharpen (make “harder”) the “softargmax()” in
your loss.
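
Putting that together with your MWE, here is a sketch of how the “soft” loss backpropagates end to end (the helper name `my_soft_loss` and the value `alpha = 5.0` are just illustrative choices):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

def my_soft_loss(inp, target, alpha=1.0):
    # weight each target value by its (sharpened) softmax probability
    return (F.softmax(alpha * inp, dim=1) * target).sum()

mm = torch.tensor([
    [0.3, 0.8, 1.3],
    [0.1, 0.2, 0.4]
])

inp = torch.rand(2, 5)           # 2 samples x 5 features
model = nn.Linear(5, 3)
pred = model(inp)

loss = my_soft_loss(pred, mm, alpha=5.0)
loss.backward()                  # succeeds: every op in the loss is differentiable
print(model.weight.grad.shape)   # torch.Size([3, 5])
```

As `alpha` grows, this approaches your original “hard” `argmax()` loss, but the gradients also shrink as the softmax saturates, so you would likely want to tune it for your use case.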