Simple custom loss with exception on backward

Hi,

I am writing a custom loss function that minimizes values taken from an intermediate matrix (mm). Simply put, the model returns 3 values per sample, and the index of the highest value is the index of the entry in the corresponding row of mm that should be used. In short, loss = sum(mm[i, argmax(pred_i)] for i, pred_i in enumerate(output)). I wrote such a custom function, but the backward step raises an exception.
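To make the intended (hard) loss concrete, here is a small sketch. The pred values are made up for illustration, and advanced indexing is shown next to the gather call used in the MWE; both pick mm[i, argmax(pred_i)] per row. Note that neither result is connected to any model parameters.

```python
import torch

# Made-up model outputs (2 samples x 3 values) and the intermediate matrix mm.
pred = torch.tensor([[0.1, 2.0, -1.0],
                     [3.0, 0.5, 0.2]])
mm = torch.tensor([[0.3, 0.8, 1.3],
                   [0.1, 0.2, 0.4]])

# Index of the highest value in each row of pred: tensor([1, 0]).
idx = pred.argmax(dim=1)

# Pick mm[i, idx[i]] for every row i and sum the picks (mm[0, 1] + mm[1, 0]).
rows = torch.arange(pred.shape[0])
hard_loss = mm[rows, idx].sum()

# The same value via gather, as in the MWE.
hard_loss_gather = torch.gather(mm, 1, idx.view(-1, 1)).sum()
print(hard_loss.item(), hard_loss_gather.item())
```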

MWE:

import torch
import torch.nn as nn
import torch.nn.functional as F

def my_loss(inp, target):
    loss = F.softmax(inp, dim=1)
    loss = torch.argmax(loss, dim=1)
    loss = torch.gather(target, 1, loss.view(-1, 1))  # Gather the corresponding values from target
    loss = loss.sum()
    return loss

# values used to calculate loss
mm = torch.tensor([
    [0.3, 0.8, 1.3],
    [0.1, 0.2, 0.4]
])

# input
inp = torch.rand((2, 5)) # 2 samples x 5 features

# model
model = nn.Linear(5, 3)

# returns 3 values where the argmax is the index of value in mm that should be used to calculate the loss
pred = model(inp)

# calculate loss
loss = my_loss(pred, mm)

loss.backward()

Error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I suspect the issue is the gather function, since it uses the target in the gradient calculation. Any idea how to solve it?

Hi adm!

Your problem is that argmax() returns integer indices (a LongTensor), so it isn’t
differentiable, and using it “breaks the computation graph.”
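A minimal sketch of where the graph breaks, reusing the MWE's shapes (the seed and random input are arbitrary assumptions). It checks requires_grad and grad_fn at each step. It also shows that gather is not the culprit: gather is differentiable with respect to its input tensor, but here that input is the constant mm and the index comes from argmax, which carries no gradient. Mathematically, argmax is piecewise constant, so even in principle its gradient would be zero almost everywhere.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(5, 3)
pred = model(torch.rand(2, 5))

probs = F.softmax(pred, dim=1)
idx = torch.argmax(probs, dim=1)

# The softmax output is still attached to the graph ...
print(probs.requires_grad, probs.grad_fn is not None)  # True True
# ... but argmax yields integer indices with no grad_fn.
print(idx.dtype, idx.requires_grad, idx.grad_fn)       # torch.int64 False None

# gather's input here is the constant mm (requires_grad=False) and its index
# carries no gradient, so the result is disconnected from the model parameters.
mm = torch.tensor([[0.3, 0.8, 1.3],
                   [0.1, 0.2, 0.4]])
loss = torch.gather(mm, 1, idx.view(-1, 1)).sum()
print(loss.requires_grad)  # False, hence the RuntimeError on loss.backward()
```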

The following may or may not make sense for your specific use case:

softmax(), conceptually, might be better named “softargmax()” (or even
“soft-one-hot-encoded-argmax()”). You could consider using:

loss = (F.softmax (inp, dim = 1) * target).sum()

as a differentiable “soft” version of your loss function. You could introduce a
parameter, loss = (F.softmax (alpha * inp, dim = 1) * target).sum(),
where larger values of alpha sharpen (make “harder”) the “softargmax()” in
your loss function.
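A sketch of the soft loss in use, under stated assumptions: the helper names soft_loss and hard_loss, the SGD optimizer, the learning rate, the step count, and the seed are all hypothetical choices for illustration. It shows that backward() now reaches the model's weights, that a few gradient steps lower the soft loss, and that a larger alpha brings the soft loss close to the hard (argmax) value for a made-up pred.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_loss(pred, target, alpha=1.0):
    # softmax(alpha * pred) acts as a "soft" one-hot encoding of argmax(pred);
    # weighting target by it gives a differentiable surrogate for
    # sum_i target[i, argmax(pred[i])].
    return (F.softmax(alpha * pred, dim=1) * target).sum()


def hard_loss(pred, target):
    # The original, non-differentiable quantity (for comparison only).
    idx = pred.argmax(dim=1, keepdim=True)
    return torch.gather(target, 1, idx).sum()


torch.manual_seed(0)
mm = torch.tensor([[0.3, 0.8, 1.3],
                   [0.1, 0.2, 0.4]])
inp = torch.rand(2, 5)
model = nn.Linear(5, 3)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

with torch.no_grad():
    initial = soft_loss(model(inp), mm).item()

# Plain gradient-descent steps on the soft loss; backward() no longer fails.
for _ in range(100):
    opt.zero_grad()
    loss = soft_loss(model(inp), mm)
    loss.backward()
    opt.step()

with torch.no_grad():
    final = soft_loss(model(inp), mm).item()
print(f"soft loss: {initial:.4f} -> {final:.4f}")

# Larger alpha pushes the soft loss toward the hard (argmax) loss.
pred = torch.tensor([[0.1, 2.0, -1.0],
                     [3.0, 0.5, 0.2]])
for alpha in (1.0, 10.0, 100.0):
    print(alpha, soft_loss(pred, mm, alpha).item())
print("hard", hard_loss(pred, mm).item())
```

One trade-off to keep in mind: a very large alpha saturates the softmax, so its gradients become tiny. Starting with a moderate alpha and increasing it during training is one option worth trying, not a guaranteed recipe.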

Best.

K. Frank