Cost function using torch.argmin()

Hi everyone,

I’m trying to use a custom loss function, something like:
loss = torch.dist(torch.argmin(predictedData).double(), actualData, p=2)
loss.backward()
This loss is a tensor with requires_grad = False, so when I try to backprop I get:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
*I think it’s because the output of torch.argmin() has requires_grad = False as well.
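A minimal reproduction of the error, assuming predictedData is a 1-D float tensor that requires grad (standing in for a model output) and actualData is a scalar tensor; the values are made up for illustration. It shows that argmin returns an integer tensor with no grad_fn, so nothing built from it can be backpropagated.

```python
import torch

torch.manual_seed(0)
predictedData = torch.rand(5, requires_grad=True)    # stand-in for a model output
actualData = torch.tensor(2.0, dtype=torch.float64)  # hypothetical target index

idx = torch.argmin(predictedData)  # int64 index, not connected to the autograd graph
print(idx.dtype, idx.requires_grad, idx.grad_fn)

loss = torch.dist(idx.double(), actualData, p=2)
print(loss.requires_grad)  # the graph was already cut at argmin

try:
    loss.backward()
except RuntimeError as err:
    print("RuntimeError:", err)
```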

Does anyone know how I can fix this issue?

Create a variable in PyTorch with requires_grad = True, then add this loss to that variable. Then call backward on the resulting variable.

Hi Ilya!

First off, I suspect that you don’t* really want to do what you think
you want to do.

However, making some guesses about what (you think) you want,
you could try something like this.

The key issue is that argmin maps onto discrete integers, and so,
is not usefully differentiable. So you can’t backpropagate through it.
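To make "not usefully differentiable" concrete: argmin is piecewise constant, so nudging any single input by a small amount leaves the returned index unchanged, and every finite-difference slope comes out as zero. A small sketch with made-up values:

```python
import torch

x = torch.tensor([0.7, 0.2, 0.9, 0.4, 0.6])  # made-up scores; the minimum is at index 1
eps = 1e-3
base = torch.argmin(x).item()

slopes = []
for i in range(x.numel()):
    x_up = x.clone()
    x_up[i] += eps  # perturb one entry slightly
    # change in the argmin index divided by the step size
    slopes.append((torch.argmin(x_up).item() - base) / eps)

print(base, slopes)
```

Where the index does change, it jumps by a whole integer, so the "derivative" is either zero or undefined, and neither gives an optimizer anything to follow.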

(Note that Jaya’s suggestion of adding your non-differentiable loss
to a requires_grad = True variable won’t work. It will get rid of
your specific error message, but it’s a fake out. Gradients won’t
backpropagate through your loss, through predictedData, or
through the model you used to predict predictedData, so you can’t
use it to train your model parameters.)
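A sketch of why that workaround is a fake out, with made-up tensors: adding the argmin-based loss to a fresh requires_grad tensor (here called dummy, a hypothetical name) lets backward() run, but the only gradient produced is the one for dummy itself. predictedData, and therefore any model upstream of it, gets none.

```python
import torch

torch.manual_seed(0)
predictedData = torch.rand(5, requires_grad=True)  # stand-in for a model output
actualData = torch.tensor(2.0, dtype=torch.float64)

# Non-differentiable loss: argmin cuts the graph
argmin_loss = torch.dist(torch.argmin(predictedData).double(), actualData, p=2)

# The suggested workaround: add it to a new tensor that requires grad
dummy = torch.tensor(0.0, dtype=torch.float64, requires_grad=True)
loss = dummy + argmin_loss
loss.backward()  # no error any more...

print(dummy.grad)          # ...but only dummy receives a gradient
print(predictedData.grad)  # None: nothing reaches the prediction
```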

Let me assume that predictedData is a vector of numbers between
0 and 1. (For simplicity, I will illustrate this without the (necessary)
batch dimension.) I will also assume that actualData is a single,
integer “index”. I think you want the index of your smallest probability
(argmin) to match actualData, and, if it doesn’t, you want your loss
to measure how far away it is in “index space.”

To do this in a differentiable, backpropagatable way, you have to get
rid of the argmin.

One approach would be to pre-compute the “index-space” distances,
and not discretize your predictedData:

import torch
torch.__version__

torch.manual_seed (2020)

nClass = 5
# dst[i, j] = |i - j|: the "index-space" distance between classes i and j
tt = torch.FloatTensor (range (nClass)).repeat (nClass).view (nClass, nClass)
dst = (tt - tt.transpose (0, 1)).abs()
dst

actualData = 2
actualData

predictedDataA = torch.ones (nClass)
predictedDataA[2] = 0.0
predictedDataB = torch.ones (nClass)
predictedDataB[4] = 0.0
predictedDataC = torch.rand (nClass)
predictedDataA
predictedDataB
predictedDataC

# small values count as "predicted", so weight (1 - value) by the distance from actualData
lossA = ((1.0 - predictedDataA) * dst[actualData]).sum(dim = 0)
lossB = ((1.0 - predictedDataB) * dst[actualData]).sum(dim = 0)
lossC = ((1.0 - predictedDataC) * dst[actualData]).sum(dim = 0)
lossA
lossB
lossC
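The transcript below was produced with PyTorch 0.3. Here is a sketch of the same loss in a more recent API style, with the batch dimension added back. It assumes predictedData has shape (batch, nClass) with values in [0, 1], and actualData holds one integer class index per sample; the labels are made up. The backward() call checks that gradients now actually reach predictedData.

```python
import torch

torch.manual_seed(2020)

nClass = 5
idx = torch.arange(nClass, dtype=torch.float32)
dst = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()  # dst[i, j] = |i - j|

batch = 3
predictedData = torch.rand(batch, nClass, requires_grad=True)  # values in [0, 1)
actualData = torch.tensor([2, 0, 4])  # one integer label per sample (int64)

# dst[actualData] picks, for each sample, the row of distances from its label
perSample = ((1.0 - predictedData) * dst[actualData]).sum(dim=1)
loss = perSample.mean()
loss.backward()

print(loss)
print(predictedData.grad)  # -dst[actualData] / batch, i.e. nonzero gradients
```

Averaging over the batch is one choice; summing would work just as well and only rescales the gradient.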

Here is the output:

>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x000002CE49626630>
>>>
>>> nClass = 5
>>> tt = torch.FloatTensor (range (nClass)).repeat (nClass).view (nClass, nClass)
>>> dst = (tt - tt.transpose (0, 1)).abs()
>>> dst

 0  1  2  3  4
 1  0  1  2  3
 2  1  0  1  2
 3  2  1  0  1
 4  3  2  1  0
[torch.FloatTensor of size 5x5]

>>>
>>> actualData = 2
>>> actualData
2
>>>
>>> predictedDataA = torch.ones (nClass)
>>> predictedDataA[2] = 0.0
>>> predictedDataB = torch.ones (nClass)
>>> predictedDataB[4] = 0.0
>>> predictedDataC = torch.rand (nClass)
>>> predictedDataA

 1
 1
 0
 1
 1
[torch.FloatTensor of size 5]

>>> predictedDataB

 1
 1
 1
 1
 0
[torch.FloatTensor of size 5]

>>> predictedDataC

 0.4869
 0.1052
 0.5883
 0.1161
 0.4949
[torch.FloatTensor of size 5]

>>>
>>> lossA = ((1.0 - predictedDataA) * dst[actualData]).sum(dim = 0)
>>> lossB = ((1.0 - predictedDataB) * dst[actualData]).sum(dim = 0)
>>> lossC = ((1.0 - predictedDataC) * dst[actualData]).sum(dim = 0)
>>> lossA

 0
[torch.FloatTensor of size 1]

>>> lossB

 2
[torch.FloatTensor of size 1]

>>> lossC

 3.8152
[torch.FloatTensor of size 1]

*) Please give us some context about the problem you’re trying to
solve. It’s likely that your approach doesn’t make sense. The main
reason is that the order of the values in predictedData won’t generally
have any quantitative meaning (unless you train it to, but that’s
another story). So “distance” in “index space” doesn’t really have
any meaning. Asking whether argmin matches actualData can
make sense, but asking how far away it is if it doesn’t match probably
doesn’t. The test against argmin suggests that actualData is an
integer categorical label, and suggests that you might be working on
some kind of a classification problem where CrossEntropyLoss might
be appropriate.
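A minimal sketch of the CrossEntropyLoss route, assuming the model produces one score per class where, as with argmin, a smaller score means "more likely". CrossEntropyLoss expects larger-is-better logits, so the scores are negated here; if a model already outputs ordinary logits, they would be passed directly. Shapes and values are made up.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, nClass = 4, 5
scores = torch.randn(batch, nClass, requires_grad=True)  # hypothetical "smaller is better" scores
actualData = torch.tensor([2, 0, 4, 1])  # integer class labels (int64)

criterion = nn.CrossEntropyLoss()
loss = criterion(-scores, actualData)  # negate so the smallest score gets the largest logit
loss.backward()

# argmin is still fine for reading off predictions; it just isn't part of the loss
with torch.no_grad():
    predicted = torch.argmin(scores, dim=1)
print(loss.item(), predicted)
```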

Best.

K. Frank

Hi Frank,

Thanks for the detailed answer. I’ll give it a try and post an update.
I think cross-entropy loss is a good idea.

*I thought about this issue regarding argmin()…

Some context:
I’m trying to build a network that predicts the inverse of a channel distortion.
I simulate a signal and pass it through a random FIR filter (the channel distortion), then pass this distorted data and the actual data to the network. The network outputs filter coefficients that are supposed to invert the FIR response.
*The filter is lfilter from torchaudio.
argmin() is supposed to act as a detector: I calculate the distance between the reconstructed signal and each code symbol, and the argmin is the detected code.
Example: symMap = pd.Series(data=[-1, 1j, -1j, 1], index=[0, 1, 2, 3])
Then if the reconstructed signal is -0.5, the detected data will be 0.
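For reference, a sketch of that nearest-symbol detector in PyTorch rather than pandas, assuming the reconstructed signal is a 1-D complex tensor and using the same four symbols as symMap. The names symbols and received are hypothetical, and the received values are made up. The result is an integer index tensor, which is exactly why it can't sit inside a loss that needs gradients.

```python
import torch

# Same alphabet as symMap: index 0 -> -1, 1 -> 1j, 2 -> -1j, 3 -> 1
symbols = torch.tensor([-1 + 0j, 0 + 1j, 0 - 1j, 1 + 0j], dtype=torch.complex64)

# Made-up reconstructed samples
received = torch.tensor([-0.5 + 0j, 0.1 + 0.8j, 0.9 - 0.2j], dtype=torch.complex64)

# |received[n] - symbols[k]| for every sample n and symbol k, shape (N, 4)
dist = (received.unsqueeze(1) - symbols.unsqueeze(0)).abs()
detected = dist.argmin(dim=1)  # index of the nearest symbol, dtype int64
print(detected)
```

The first sample, -0.5, is closest to -1 (distance 0.5), so it is detected as index 0, matching the example above.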

Thanks,
Ilya

Update:
For now, I used torch.dist() without argmin(), applied to the pre-detection signal.
This seems to work, and the model is training.
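A sketch of what that update amounts to, under stated assumptions: each complex sample is stored as an (I, Q) pair in a real tensor of shape (N, 2), and a noisy copy of the transmitted symbols stands in for the network's reconstructed output (all names here are hypothetical). The loss is computed on the continuous, pre-detection values; nearest-symbol detection is only used afterwards, without gradients, to count symbol errors.

```python
import torch

torch.manual_seed(0)

# Symbol alphabet as (I, Q) pairs: -1, 1j, -1j, 1
symbols = torch.tensor([[-1.0, 0.0], [0.0, 1.0], [0.0, -1.0], [1.0, 0.0]])

sentIdx = torch.randint(0, 4, (16,))  # hypothetical transmitted symbol indices
sent = symbols[sentIdx]               # shape (16, 2)

# Stand-in for the network's reconstructed (equalized) signal
reconstructed = (sent + 0.05 * torch.randn(16, 2)).requires_grad_(True)

# Training loss on the pre-detection signal: differentiable
loss = torch.dist(reconstructed, sent, p=2)
loss.backward()

# Detection for evaluation only: nearest symbol, no gradients needed
with torch.no_grad():
    detected = torch.cdist(reconstructed, symbols).argmin(dim=1)
    symbolErrors = (detected != sentIdx).sum().item()
print(loss.item(), symbolErrors)
```

torch.dist treats the whole (16, 2) difference as one flattened vector, so this is the Euclidean norm of the total error; a mean squared error such as torch.nn.functional.mse_loss is a common alternative.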

Thanks, Frank, for the advice.