Sum is zero, but the gradient still flows when multiplied with a one-hot vector

To explain my issue, I’ll go through it step by step.
First, let’s build a simple model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

Let’s create the model, the optimizer, and a random input:

model = Model()
optimizer = torch.optim.SGD(model.parameters(),lr=1)
x = torch.rand(10)

Let’s run the model:

out = model(x)

Instead of a standard loss function, we’ll compute the following:

prob1 = F.softmax(out,dim=-1)
prob2 = F.softmax(out,dim=-1)
loss = prob1 - prob2.detach()
loss = torch.sum(loss)
loss.backward()

The loss is supposed to be zero, and the gradient should flow through only one of the two softmaxes (prob1, since prob2 is detached).
Let’s print the loss and the gradient of fc.weight:

loss: tensor(0., grad_fn=<SumBackward0>)
gradient: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

So far, this makes sense to me.
Now let’s make a small change and add a target:

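model.zero_grad()                      # clear the gradients from the first run
out = model(x)                         # rebuild the graph; the first backward() freed it
target = torch.tensor([1])
target_onehot = F.one_hot(target, 2)   # tensor([[0, 1]])
prob1 = F.softmax(out, dim=-1)
prob2 = F.softmax(out, dim=-1)
loss = prob1 - prob2.detach()
loss = torch.sum(loss * target_onehot)
loss.backward()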

Now let’s run and print again:

sum_loss: tensor(0., grad_fn=<SumBackward0>)
gradient: tensor([[-0.1637, -0.1643, -0.2050, -0.2044, -0.1300, -0.0278, -0.2135, -0.0859,
                   -0.2186, -0.0658],
                  [ 0.1637, 0.1643, 0.2050, 0.2044, 0.1300, 0.0278, 0.2135, 0.0859,
                    0.2186, 0.0658]])

The target is multiplied by loss, which is a vector of zeros, so how can we now have non-zero gradients?
And even if there were gradients, why would multiplying by target make a difference? Isn’t it an unimportant part of the backward graph?
What am I missing?
Thanks a lot!!!

Hi,

The fact that the loss is 0 does not mean the gradient is 0.
The gradient that flows back from the backward() call starts at 1: autograd computes a vector-Jacobian product, and when that vector is just 1, you get the Jacobian (here, the gradient) itself.
The big difference between the two cases is that in the first one, the gradient that flows back to the softmax is a vector full of 1s, while in the second one, that vector has been multiplied by target_onehot (the backward of the multiplication), so it contains far fewer 1s and is no longer symmetric. I guess this is where these non-zero values come from.
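One way to see this is to inspect the vector that actually reaches the softmax’s backward. The sketch below assumes a fixed two-element tensor `out` in place of the model output and uses `retain_grad()` to keep the gradient of `prob1`; the helper name `grads` is just for illustration.

```python
import torch
import torch.nn.functional as F

def grads(weight=None):
    """Illustrative helper: return (dloss/dprob1, dloss/dout).

    weight=None reproduces torch.sum(prob1 - prob2.detach());
    otherwise the difference is multiplied by `weight` before summing.
    """
    out = torch.tensor([0.3, -0.2], requires_grad=True)  # stand-in for the model output
    prob1 = F.softmax(out, dim=-1)
    prob1.retain_grad()                 # keep .grad on this non-leaf tensor
    prob2 = F.softmax(out, dim=-1)
    diff = prob1 - prob2.detach()
    loss = diff.sum() if weight is None else (diff * weight).sum()
    loss.backward()                     # same as loss.backward(torch.tensor(1.))
    return prob1.grad, out.grad

g_prob, g_out = grads()
print(g_prob)   # tensor([1., 1.]): the vector fed into softmax's backward
print(g_out)    # zero, up to float rounding

g_prob, g_out = grads(F.one_hot(torch.tensor([1]), 2))
print(g_prob)   # tensor([0., 1.]): only the target entry survives
print(g_out)    # non-zero, with equal and opposite entries
```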

I see your point, but my calculations just don’t agree with that.
To make things a bit clearer, I’ll change the notation to:

loss1 = torch.sum(loss*target_onehot)

And let’s denote:
S(out) = softmax(out)

Now:
dloss1/dloss = [ target_onehot[0] , target_onehot[1]]
dloss/dout=[[S(out[0])S(out[0]),S(out[0])(1-S(out[1])) ],
[S(out[1])(1-S(out[0])),S(out[1])S(out[1])]]
dout/db = [[1 , 1] , [1, 1]]
dloss1/db = dloss1/dloss * dloss/dout * dout/db
If dloss1/dloss is [1, 1], I don’t see how that makes the gradient equal zero (except for very specific values of S).

You only need to look at dloss1/dloss * dloss/dout.
Here you compute a vector-matrix product: [1, 1] * [[S0(1-S0), -S0*S1], [-S1*S0, S1(1-S1)]].
And so the result is: [S0(1 - S0 - S1), S1(1-S1-S0)]. But because your softmax has only two entries, we know that S0 + S1 = 1. So the result for the gradient is [0, 0].
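The same argument can be checked numerically with torch.autograd.functional.jacobian. A minimal sketch, assuming a fixed two-element `out` and a random input `x` (both made up for illustration): it compares the autograd Jacobian with the closed form above, then multiplies it by the upstream vector for each case. The last lines assume the usual linear-layer rule that the weight gradient is the outer product of the output gradient and the input, which would explain why the two rows of the gradient in the question are negatives of each other.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian

out = torch.tensor([0.3, -0.2])                    # stand-in for the model output
S = F.softmax(out, dim=-1)
J = jacobian(lambda o: F.softmax(o, dim=-1), out)  # J[i, j] = dS_i / dout_j

# Closed form: [[S0(1-S0), -S0*S1], [-S1*S0, S1(1-S1)]]
J_closed = torch.diag(S) - torch.outer(S, S)
print(torch.allclose(J, J_closed))                 # True

ones = torch.tensor([1.0, 1.0])                    # upstream gradient for the plain sum
onehot = torch.tensor([0.0, 1.0])                  # upstream gradient after * target_onehot
print(ones @ J)      # [S0(1-S0-S1), S1(1-S1-S0)], i.e. zero since S0 + S1 = 1
print(onehot @ J)    # [-S0*S1, S1*(1-S1)] = [-S0*S1, S0*S1], non-zero

# For out = W @ x + b, dloss/dW is the outer product of dloss/dout and x,
# so the two rows are the same vector with opposite signs.
x = torch.rand(10)
grad_W = torch.outer(onehot @ J, x)
print(grad_W[0] + grad_W[1])                       # zero, up to float rounding
```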


Thanks, I forgot that detail, and I also made a big mistake in the softmax derivative.
Thanks again.
