What's wrong with the grad of softmax when I use only some of the inputs for the softmax?

I use softmax and then log to mimic cross-entropy while trying to implement knowledge distillation.

soft_old_output = soft_target.clone()/T
soft_new_output = outputs.clone()/T
soft_new_output.retain_grad()
outputs_S = F.softmax(soft_new_output,dim=1)
outputs_S.retain_grad()
outputs_T = F.softmax(soft_old_output,dim=1)
loss2 = outputs_T.mul(-1*torch.log(outputs_S))
loss2 = loss2.sum(1)
loss2 = loss2.mean()
loss = loss2
loss.backward(retain_graph=True)
print(soft_new_output.grad)

At first, soft_old_output is the same as soft_new_output, so the grad of soft_new_output is all zeros. That's correct.
But when I try to use only some of the values, like this:

soft_old_output = soft_target[:,:2].clone()/T
soft_new_output = outputs[:,:2].clone()/T
soft_new_output.retain_grad()
outputs_S = F.softmax(soft_new_output,dim=1)
outputs_S.retain_grad()
outputs_T = F.softmax(soft_old_output,dim=1)
loss2 = outputs_T.mul(-1*torch.log(outputs_S))
loss2 = loss2.sum(1)
loss2 = loss2.mean()
loss = loss2
loss.backward(retain_graph=True)
print(soft_new_output.grad)

The grad of soft_new_output is then:

tensor([[-5.8213e-08, -1.3913e-09]])

Why aren’t they zero?

The small absolute error is most likely caused by the limited numerical precision of float32, and you should get a smaller error using float64.
Also note that using torch.log(torch.softmax(...)) is numerically less stable than F.log_softmax, as the latter applies the log-sum-exp trick to increase stability.
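
Here is a small sketch of both points (the logits and the temperature are made up for illustration): the same loss written with F.log_softmax, once in float32 and once in float64. The float64 gradients should end up much closer to zero.

import torch
import torch.nn.functional as F

T = 2.0                      # hypothetical distillation temperature
logits = torch.randn(1, 10)  # hypothetical logits

for dtype in (torch.float32, torch.float64):
    # teacher and student use the same values, so the true gradient is zero
    student = (logits[:, :2] / T).to(dtype).requires_grad_(True)
    teacher = (logits[:, :2] / T).to(dtype)
    # F.log_softmax applies the log-sum-exp trick internally
    loss = -(F.softmax(teacher, dim=1) * F.log_softmax(student, dim=1)).sum(1).mean()
    loss.backward()
    print(dtype, student.grad)

Mathematically the gradient w.r.t. the student logits is softmax(student) - softmax(teacher), which is exactly zero when both inputs are identical, so anything you see in the printed grads is just rounding noise.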

Thank you very much~
And I have another question now. When I tried to find the problem myself, I tried to track the gradient changes, but I don't know where to find the grad_fn of each variable. Where can I find them? I mean, I want to know how they work so that I can analyze which part leads to this problem, but I can't find them.

The grad_fn is an attribute of a tensor that was created by a differentiable operation and that Autograd has tracked.
Here is a small example:

import torch

x = torch.randn(1)                      # plain tensor, not tracked by Autograd
w = torch.randn(1, requires_grad=True)  # leaf tensor tracked by Autograd
y = x * w                               # differentiable op, so y gets a grad_fn
print(y.grad_fn)
> <MulBackward0 object at ...>
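
If you want to follow the graph further, each grad_fn also holds the backward nodes of its inputs in next_functions (continuing the example above):

print(y.grad_fn.next_functions)
> ((None, 0), (<AccumulateGrad object at ...>, 0))

x doesn't require gradients, so its entry is None, while w is a leaf tensor whose gradient will be accumulated by AccumulateGrad. Walking through next_functions this way lets you inspect the backward graph Autograd has built.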

I know that, but what does MulBackward0 do? It's easy to understand, but what about the other grad_fns? Do they just work as the mathematical formulas say, or do they use any tricks? I want to find their code to have a look.

Yes, they should be implemented using the correct mathematical derivatives.
You can find the implementations in derivatives.yaml, where these methods are either defined directly or point to the name of the implementation.
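
For the multiplication example above you can also quickly check that MulBackward0 matches the textbook derivative d(x*w)/dw = x (just a sanity check from the Python side, not the internal implementation):

import torch

x = torch.randn(1)
w = torch.randn(1, requires_grad=True)
y = x * w
y.backward()
print(torch.allclose(w.grad, x))
> True

If you want a numerical comparison against the analytical gradients in double precision, torch.autograd.gradcheck is also helpful.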

That’s just what I need. Thank you very much. :+1: