Hi Ilya!
First off, I suspect that you don’t* really want to do what you think
you want to do.
However, making some guesses about what (you think) you want,
you could try something like this.
The key issue is that argmin maps onto discrete integers, and so is not usefully differentiable. This means you can’t backpropagate through it.
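As a concrete illustration, here is a minimal sketch (using a more recent pytorch API than the version-0.3 example below) showing that argmin returns an integer tensor that carries no grad_fn, so autograd has nothing to follow back through it:
import torch
p = torch.rand (5, requires_grad = True)
idx = p.argmin()
idx.dtype     # torch.int64 -- integer indices, not floats
idx.grad_fn   # None -- no gradient path back to p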
(Note that Jaya’s suggestion of adding your non-differentiable loss to a requires_grad = True variable won’t work. It will get rid of your specific error message, but it’s a fake out. Gradients won’t backpropagate through your loss, nor through predictedData, nor through the model you used to predict predictedData, so you can’t use it to train your model parameters.)
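Here is a minimal sketch of why it’s a fake out (again with a more recent pytorch API; p is just a hypothetical stand-in for your model output). backward() runs without complaint, but no gradient ever reaches p:
import torch
p = torch.rand (5, requires_grad = True)              # stand-in for your model output
bad_loss = (p.argmin() - 2).abs().float()             # non-differentiable quantity
fake = bad_loss + torch.zeros (1, requires_grad = True)
fake.backward()                                       # runs with no error ...
p.grad                                                # ... but None -- p got no gradient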
Let me assume that predictedData is a vector of numbers between 0 and 1. (For simplicity, I will illustrate this without the (necessary) batch dimension.) I will also assume that actualData is a single integer “index”. I think you want the index of your smallest probability (argmin) to match actualData, and, if it doesn’t, you want your loss to measure how far away it is in “index space.”
To do this in a differentiable, backpropagatable way, you have to get rid of the argmin.
One approach would be to pre-compute the “index-space” distances, and not discretize your predictedData:
import torch
torch.__version__
torch.manual_seed (2020)

nClass = 5
# dst[i][j] = |i - j|, the "index-space" distance between indices i and j
tt = torch.FloatTensor (range (nClass)).repeat (nClass).view (nClass, nClass)
dst = (tt - tt.transpose (0, 1)).abs()
dst

# the target index
actualData = 2
actualData

# three sample predictions: A has its minimum at index 2 (a perfect
# match), B has its minimum at index 4, and C is random
predictedDataA = torch.ones (nClass)
predictedDataA[2] = 0.0
predictedDataB = torch.ones (nClass)
predictedDataB[4] = 0.0
predictedDataC = torch.rand (nClass)
predictedDataA
predictedDataB
predictedDataC

# weight how close each element is to zero by that element's distance
# from actualData, so small values far from actualData cost the most
lossA = ((1.0 - predictedDataA) * dst[actualData]).sum(dim = 0)
lossB = ((1.0 - predictedDataB) * dst[actualData]).sum(dim = 0)
lossC = ((1.0 - predictedDataC) * dst[actualData]).sum(dim = 0)
lossA
lossB
lossC
Here is the output:
>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x000002CE49626630>
>>>
>>> nClass = 5
>>> tt = torch.FloatTensor (range (nClass)).repeat (nClass).view (nClass, nClass)
>>> dst = (tt - tt.transpose (0, 1)).abs()
>>> dst
0 1 2 3 4
1 0 1 2 3
2 1 0 1 2
3 2 1 0 1
4 3 2 1 0
[torch.FloatTensor of size 5x5]
>>>
>>> actualData = 2
>>> actualData
2
>>>
>>> predictedDataA = torch.ones (nClass)
>>> predictedDataA[2] = 0.0
>>> predictedDataB = torch.ones (nClass)
>>> predictedDataB[4] = 0.0
>>> predictedDataC = torch.rand (nClass)
>>> predictedDataA
1
1
0
1
1
[torch.FloatTensor of size 5]
>>> predictedDataB
1
1
1
1
0
[torch.FloatTensor of size 5]
>>> predictedDataC
0.4869
0.1052
0.5883
0.1161
0.4949
[torch.FloatTensor of size 5]
>>>
>>> lossA = ((1.0 - predictedDataA) * dst[actualData]).sum(dim = 0)
>>> lossB = ((1.0 - predictedDataB) * dst[actualData]).sum(dim = 0)
>>> lossC = ((1.0 - predictedDataC) * dst[actualData]).sum(dim = 0)
>>> lossA
0
[torch.FloatTensor of size 1]
>>> lossB
2
[torch.FloatTensor of size 1]
>>> lossC
3.8152
[torch.FloatTensor of size 1]
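(Note that lossA comes out to zero because predictedDataA has its minimum exactly at actualData = 2, while lossB equals 2, the “index-space” distance from its minimum, at index 4, to actualData.)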
*) Please give us some context about the problem you’re trying to solve. It’s likely that your approach doesn’t make sense. The main reason is that the order of the values in predictedData won’t generally have any quantitative meaning (unless you train it to, but that’s another story). So “distance” in “index space” doesn’t really have any meaning. Asking whether argmin matches actualData can make sense, but asking how far away it is if it doesn’t match probably doesn’t. The test against argmin suggests that actualData is an integer categorical label, and suggests that you might be working on some kind of a classification problem where CrossEntropyLoss might be appropriate.
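If that is the case, the usage would be, roughly (a minimal sketch with a more recent pytorch API and made-up shapes, not your actual data):
import torch
nClass = 5
loss_fn = torch.nn.CrossEntropyLoss()
logits = torch.randn (3, nClass, requires_grad = True)   # raw, unnormalized scores for a batch of 3
labels = torch.tensor ([2, 0, 4])                        # integer categorical class labels
loss = loss_fn (logits, labels)
loss.backward()                                          # gradients flow back through logits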
Best.
K. Frank