Cannot get the gradient of kl_div

Hi, I am having trouble getting the gradient when I use the kl_div function.

I tried two versions of the kl_div function: the official one, F.kl_div, and one I implemented myself. Both return the same loss value, but the gradient obtained through F.kl_div is always 0.

import torch
import torch.nn.functional as F


def kl_div(x, y):
    # x: log-probabilities (input), y: probabilities (target)
    _kl = torch.sum(y * (torch.log(y) - x), -1)
    return torch.mean(_kl)


logits = torch.tensor([[1., 2., 3., 4.]], requires_grad=True)

# synthetic (soft) label; this is the tensor whose gradient I want
label = torch.tensor([[0.1, 0.1, 0.7, 0.1]], requires_grad=True)
# true class index used in the cross entropy step
y_qry = torch.tensor([1])

logsoftlogits = F.log_softmax(logits, dim=-1)

case = 0  # 0: my kl_div, anything else: F.kl_div
if case == 0:
    loss = kl_div(logsoftlogits, label)
else:
    loss = F.kl_div(logsoftlogits, label, reduction='sum') / logsoftlogits.size(0)

# inner step: update the logits with the gradient of the KL loss
grad = torch.autograd.grad(outputs=loss, inputs=logits, create_graph=True)
logits_q = logits - grad[0]

# outer step: cross entropy of the updated logits against the true label
loss_q = F.cross_entropy(logits_q, y_qry)

labelgrad = torch.autograd.grad(outputs=loss_q, inputs=label, grad_outputs=torch.ones(loss_q.size()))
print(labelgrad)

The official kl_div result:

(tensor([[0., 0., 0., 0.]]),)

My implementation result:

(tensor([[-0.2615, -1.1997, 0.1304, 0.1274]]),)

Hi Ereebay!

I don’t follow fully what you are doing with two layers of losses and grads,
and I will comment in terms of the older 0.3.0 version of pytorch.

I believe what you are seeing is that F.kl_div() doesn’t track gradients
with respect to its second argument (your label). (I think this is a
“feature” in the interest of efficiency.) The version of kl_div() that you
wrote isn’t “smart” enough to not compute the gradient with respect
to its second argument.

Please see this related thread:

Here is a simple pytorch 0.3.0 script that illustrates this for F.kl_div():

import torch
torch.__version__

torch.manual_seed (2020)

# make some log-probabilities and probabilities
logprobs = torch.nn.functional.log_softmax (torch.autograd.Variable (torch.randn (2, 5)), dim = 1).data
probs1 = torch.nn.functional.softmax (torch.autograd.Variable (torch.randn (2, 5)), dim = 1).data
probs2 = probs1.clone()

# make leaf-variable Variables
# input values are log-probabilities
# input  = torch.autograd.Variable (torch.randn (2, 5), is_leaf = True, requires_grad = True), dim = 1)
input = torch.autograd.Variable (logprobs, requires_grad = True)
# target values are probabilities
targ_grad = torch.autograd.Variable (probs1, requires_grad = True)
targ_nograd = torch.autograd.Variable (probs2, requires_grad = False)

loss = torch.nn.functional.kl_div (input, targ_grad)
loss
loss.backward()
input.grad
targ_grad.grad

loss = torch.nn.functional.kl_div (input, targ_nograd)
loss
loss.backward()
input.grad
targ_nograd.grad

And here is the output:

>>> import torch
>>> torch.__version__
'0.3.0b0+591e73e'
>>>
>>> torch.manual_seed (2020)
<torch._C.Generator object at 0x000002904D7C6630>
>>>
>>> # make some log-probabilities and probabilities
... logprobs = torch.nn.functional.log_softmax (torch.autograd.Variable (torch.randn (2, 5)), dim = 1).data
>>> probs1 = torch.nn.functional.softmax (torch.autograd.Variable (torch.randn (2, 5)), dim = 1).data
>>> probs2 = probs1.clone()
>>>
>>> # make leaf-variable Variables
... # input values are log-probabilities
... # input  = torch.autograd.Variable (torch.randn (2, 5), is_leaf = True, requires_grad = True), dim = 1)
... input = torch.autograd.Variable (logprobs, requires_grad = True)
>>> # target values are probabilities
... targ_grad = torch.autograd.Variable (probs1, requires_grad = True)
>>> targ_nograd = torch.autograd.Variable (probs2, requires_grad = False)
>>>
>>> loss = torch.nn.functional.kl_div (input, targ_grad)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: the derivative for 'target' is not implemented
>>> loss
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name 'loss' is not defined
>>> loss.backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name 'loss' is not defined
>>> input.grad
>>> targ_grad.grad
>>>
>>> loss = torch.nn.functional.kl_div (input, targ_nograd)
>>> loss
Variable containing:
 0.2735
[torch.FloatTensor of size 1]

>>> loss.backward()
>>> input.grad
Variable containing:
1.00000e-02 *
 -0.6291 -1.5081 -5.4165 -1.8988 -0.5475
 -0.8674 -0.3891 -0.4075 -0.6592 -7.6768
[torch.FloatTensor of size 2x5]

>>> targ_nograd.grad
>>>

The pytorch 0.3.0 version of F.kl_div() simply refuses to accept a second
argument that has requires_grad = True, and throws an error. If the
second argument has requires_grad = False, F.kl_div() works, and
you get a gradient for its first argument, but, of course, not its second.

Does this look like the root cause of what you are seeing?

Best.

K. Frank

Thanks for your reply, Frank. Sorry for my poor English.

First, let me explain what I’m doing with these two layers. I have logits, which are the output of a classifier, and a fake (synthetic) label as the target. I use kl_div to compute the divergence between them. Then I use the gradient of that loss to update the logits, and I use the new logits to compute the cross entropy against the correct label (y_qry). After that, I use the gradient of the cross entropy with respect to the synthetic label to update the synthetic label.
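
The two-step procedure described here can be checked by hand, without autograd. The following is a minimal sketch in plain Python, not PyTorch; it assumes a batch of one and a step size of 1 as in the script at the top of the thread, and the helper names (inner_grad, outer_loss, label_grad_analytic, label_grad_numeric) are made up for illustration. It relies on the fact that the gradient of sum(label * (log(label) - log_softmax(logits))) with respect to the logits is softmax(logits) * sum(label) - label, so the updated logits depend on the label.

```python
import math


def softmax(v):
    m = max(v)
    e = [math.exp(a - m) for a in v]
    z = sum(e)
    return [a / z for a in e]


def inner_grad(logits, label):
    # d/d(logits) of sum(label * (log(label) - log_softmax(logits)))
    s = softmax(logits)
    total = sum(label)
    return [si * total - yi for si, yi in zip(s, label)]


def outer_loss(logits, label, target_index):
    # one gradient step on the logits, then cross entropy against the true class
    g = inner_grad(logits, label)
    logits_q = [l - gi for l, gi in zip(logits, g)]
    return -math.log(softmax(logits_q)[target_index])


def label_grad_analytic(logits, label, target_index):
    s = softmax(logits)
    g = inner_grad(logits, label)
    logits_q = [l - gi for l, gi in zip(logits, g)]
    p = softmax(logits_q)
    # d(loss_q)/d(logits_q) = softmax(logits_q) - one_hot(target)
    d = [pi - (1.0 if i == target_index else 0.0) for i, pi in enumerate(p)]
    # d(logits_q[i])/d(label[j]) = delta_ij - s[i]
    dot = sum(di * si for di, si in zip(d, s))
    return [dj - dot for dj in d]


def label_grad_numeric(logits, label, target_index, eps=1e-6):
    # central finite differences through the whole two-step pipeline
    out = []
    for j in range(len(label)):
        hi = list(label)
        lo = list(label)
        hi[j] += eps
        lo[j] -= eps
        diff = outer_loss(logits, hi, target_index) - outer_loss(logits, lo, target_index)
        out.append(diff / (2 * eps))
    return out


logits = [1.0, 2.0, 3.0, 4.0]
label = [0.1, 0.1, 0.7, 0.1]
analytic = label_grad_analytic(logits, label, 1)
numeric = label_grad_numeric(logits, label, 1)
print([round(v, 4) for v in analytic])
print([round(v, 4) for v in numeric])
```

With the inputs from the original script, both lists should agree with the values reported above for the hand-written kl_div to roughly four decimal places. That suggests the non-zero gradient is the mathematically expected one and that an all-zero result drops the dependence of the updated logits on the label.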

I believe what you are seeing is that F.kl_div() doesn’t track gradients
with respect to its second argument (your label). (I think this is a
“feature” in the interest of efficiency.)

Second, when I use torch.autograd.grad to compute the gradient of the label directly from the kl_div loss, gradients with respect to both arguments are tracked normally. But when I compute the gradient of the label through the cross entropy loss, it is not tracked. Is it possible that a tensor cannot be tracked through another loss function?

Hello Ereebay!

As shown in my post above, in pytorch 0.3.0, F.kl_div() simply
does not track the gradient of its second argument (label).

I downgraded, for whatever reason, to pytorch 0.3.0, so I can’t test
this on later versions. Could you tell us what version of pytorch you
are using?

(Perhaps someone with pytorch 1.x could run my test on a more
up-to-date version.)

Could you check in a simple script whether or not you can track gradients
with respect to both arguments of F.kl_div() in your version of pytorch?
(Don’t use your loss-within-a loss code – just make a single call to
F.kl_div() for testing purposes.)
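
As a reference for such a single-call test, the summed KL term sum(y * (log(y) - x)) has simple closed-form derivatives with respect to both arguments: -y for the input x (log-probabilities) and log(y) + 1 - x for the target y (probabilities). The following is a minimal sketch in plain Python that checks both against finite differences. It makes no claim about what any particular pytorch version computes, and the helper names (kl_sum, grad_input, grad_target, finite_diff) are hypothetical.

```python
import math


def kl_sum(x, y):
    # x: log-probabilities (input), y: probabilities (target)
    return sum(yi * (math.log(yi) - xi) for xi, yi in zip(x, y))


def grad_input(x, y):
    # derivative of kl_sum with respect to each x[j]
    return [-yi for yi in y]


def grad_target(x, y):
    # derivative of kl_sum with respect to each y[j], treating y as free variables
    return [math.log(yi) + 1.0 - xi for xi, yi in zip(x, y)]


def finite_diff(f, v, eps=1e-6):
    # central finite differences of a scalar function f at the point v
    out = []
    for j in range(len(v)):
        hi = list(v)
        lo = list(v)
        hi[j] += eps
        lo[j] -= eps
        out.append((f(hi) - f(lo)) / (2 * eps))
    return out


logits = [1.0, 2.0, 3.0, 4.0]
lse = math.log(sum(math.exp(l) for l in logits))
x = [l - lse for l in logits]  # log_softmax of the logits
y = [0.1, 0.1, 0.7, 0.1]

num_x = finite_diff(lambda v: kl_sum(v, y), x)
num_y = finite_diff(lambda v: kl_sum(x, v), y)
print(grad_input(x, y), num_x)
print(grad_target(x, y), num_y)
```

The values printed by a real F.kl_div() test (with reduction='sum') can be compared against these numbers for the same inputs.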

As for whether a tensor can fail to be tracked through a second loss
function: I don’t know for a fact, but I would not expect this to be the
case. As far as I understand it, loss functions are just regular
(autograd-supporting) tensor functions, so I would expect gradient-tracking
to work as well for your loss-within-a-loss use case as it does for the
individual loss functions separately.

Best.

K. Frank

Hi Frank, I am using pytorch 1.2.0. I checked both arguments of F.kl_div() in my version of pytorch, and they can be tracked normally. I also ran into the same problem with the L1 loss function.

Hello Ereebay!

I can’t test this myself on pytorch 1.x, but looking at the code for kl_div()
(specifically line 86, grad_input_val = -target_val * grad_val;),
I only see the gradient being computed for the first (input, your logits)
argument, and not for the second (target, your label).

(I assume that this code is for the current “stable” pytorch version, so
newer than your 1.2.0.)

Could you post your (short, complete, runnable) test script, together
with its output? Maybe somebody with a non-archaic copy of pytorch
could help look into this.

Best.

K. Frank