Alright so this looks like such a simple issue that I’m afraid I’ve misunderstood some of the core workings of PyTorch/autograd.

My problem is straightforward: I have a gradient of a sum that turns out 0 while it shouldn’t be, and it doesn’t match the sum of the gradients separately.

I’ve concocted a simplified example to illustrate. We’re in the probability domain, or rather, the logprob domain, so this is the straightforward approximation of probabilities using a softmax on top of a linear layer. Well, not quite a softmax, because logprobs. There are two output classes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

feat_dim, out_dim = 1, 2
encoder = nn.Linear(feat_dim, out_dim, bias=False)
nn.init.zeros_(encoder.weight.data)
# Fabricate some one-hot encoded data of batch size 1, good enough
batch_size = 1
input = torch.eye(feat_dim)[:batch_size]
```
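For reference, the "not quite a softmax" construction used in this thread is logits minus their `logsumexp`. A minimal sketch, assuming the same setup as above, that checks it against `F.log_softmax` (the name `logprobs` is introduced here only for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

feat_dim, out_dim = 1, 2
encoder = nn.Linear(feat_dim, out_dim, bias=False)
nn.init.zeros_(encoder.weight)
input = torch.eye(feat_dim)[:1]

logits = encoder(input)                               # shape (1, 2)
norm = torch.logsumexp(logits, dim=1, keepdim=True)   # shape (1, 1)
logprobs = logits - norm                              # norm is broadcast across both classes

# The manual construction should agree with the built-in log-softmax
print(torch.allclose(logprobs, F.log_softmax(logits, dim=1)))
# Exponentiated logprobs should sum to 1 per row
print(logprobs.exp().sum(dim=1))
```

Note that `norm` has a single entry per row but is subtracted from every class, which matters for the gradients discussed below.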

As I said, I have the gradient of a sum, or rather of a difference really. Below I’ll show the gradients of the linear layer’s weights after backpropagating first the difference, then the two terms separately.

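```python
# 1) Difference: logits - norm
logits = encoder(input)
norm = torch.logsumexp(logits, dim=1, keepdim=True)
(logits - norm).sum().backward()
print(encoder.weight.grad)
encoder.weight.grad = None  # reset, otherwise gradients accumulate across backward calls

# 2) logits alone
logits = encoder(input)
logits.sum().backward()
print(encoder.weight.grad)
encoder.weight.grad = None

# 3) norm alone
logits = encoder(input)
norm = torch.logsumexp(logits, dim=1, keepdim=True)
norm.sum().backward()
print(encoder.weight.grad)
```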

This gives (annotated):

```
# grad of logits - norm
tensor([[0.],
        [0.]])
# grad of logits
tensor([[1.],
        [1.]])
# grad of norm
tensor([[0.5000],
        [0.5000]])
```

The gradients for `logits` and `norm` are exactly what I expected. I expected the gradient for the logprobs, `logits - norm`, to be the difference of the gradients of `logits` and `norm`, but that’s not the case. Instead it’s zeros, for no reason I can see.

I must be missing something incredibly simple. Maybe the gradient of `norm` gets backpropagated multiple times, but I don’t see it. Could anyone shed some light on the issue?

You can look at your situation as finding the derivative of f(g(h(w))), and the values you get are h’ = 1, g’ = 1/2 and f’ = 0. Since this is a composition, f’ != h’ + g’.

To see specifically why you get zero, you can decorate the tensors in your graph with hooks that print the gradients flowing backward:

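```python
import torch
import torch.nn as nn

feat_dim, out_dim = 1, 2
encoder = nn.Linear(feat_dim, out_dim, bias=False)
nn.init.zeros_(encoder.weight.data)
batch_size = 1
input = torch.eye(feat_dim)[:batch_size]

def decorate(tensor, name):
    # Print the gradient that arrives at `tensor` during the backward pass
    def print_backward(grad):
        print(f"({name}): ", grad)
    tensor.register_hook(print_backward)

# Gradient reaching logits directly from the `logits` term
logits = encoder(input)
decorate(logits, 'logits')
logits.backward(torch.ones_like(logits))

# Gradient reaching logits through the broadcast `norm` term
logits = encoder(input)
norm = torch.logsumexp(logits, dim=1, keepdim=True)
decorate(logits, 'norm')
norm.expand_as(logits).backward(torch.ones_like(logits))
```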

This prints the following:

```
(logits):  tensor([[1., 1.]])
(norm):  tensor([[1., 1.]])
```

Since the gradients reaching `logits` directly and through `norm` are both vectors of 1’s, they cancel out in the difference. The following is an illustration of how the 1’s come about:


That is a beautifully crafted reply @Yaroslav_Bulatov. Many thanks for the informative graphic! And for the illustration of how to use `register_hook`, which should come in useful in the future.

It shows what really happens: PyTorch has summed up the partial derivatives over the different outputs. Which is why you’ve got 0.5 in the graphic, but only a 1D vector of derivatives in the backward hook.
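Written out for the two-class case, the summation over outputs can be sketched as follows. The symbols are introduced here for illustration only: $\ell_1, \ell_2$ are the logits, $n = \log(e^{\ell_1} + e^{\ell_2})$ is `norm`, and $S$ is the scalar that gets backpropagated. Because `norm` is broadcast over both classes, it appears twice in $S$:

$$
S = \sum_{i=1}^{2} (\ell_i - n) = \ell_1 + \ell_2 - 2n,
\qquad
\frac{\partial S}{\partial \ell_j} = 1 - 2\,\frac{e^{\ell_j}}{e^{\ell_1} + e^{\ell_2}} = 1 - 2 \cdot \tfrac{1}{2} = 0
\quad \text{at } \ell_1 = \ell_2 = 0 .
$$

Backpropagating `norm.sum()` on its own counts $n$ only once, which gives $\partial n / \partial \ell_j = \tfrac{1}{2}$, so subtracting the two separately computed gradients yields $1 - \tfrac{1}{2}$ rather than $1 - 2 \cdot \tfrac{1}{2}$. With the input equal to 1, the gradient with respect to each weight equals the gradient with respect to the corresponding logit.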

I would still argue my original example of `logits - norm` is a sum however. And if I were to look at the partial derivatives individually, that sum would match. The problem was that I was looking at the sum of the derivatives (summed over outputs), which is what PyTorch gives us.
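One way to check this claim is to look at the full Jacobian instead of the summed gradient. The following is a minimal sketch, assuming the same zero weights and one-hot input as in the thread; the helper names `logits_fn`, `norm_fn` and `logprobs_fn` are made up for this illustration, and the weight is passed in explicitly so that `torch.autograd.functional.jacobian` can differentiate with respect to it.

```python
import torch
from torch.autograd.functional import jacobian

x = torch.eye(1)[:1]       # same one-hot input, shape (1, 1)
w = torch.zeros(2, 1)      # zero-initialised weight, shape (out_dim, feat_dim)

def logits_fn(w):
    return x @ w.t()       # shape (1, 2)

def norm_fn(w):
    return torch.logsumexp(logits_fn(w), dim=1, keepdim=True)  # shape (1, 1)

def logprobs_fn(w):
    return logits_fn(w) - norm_fn(w)   # norm is broadcast to shape (1, 2)

# Rows index outputs, columns index the two weights
J_logits = jacobian(logits_fn, w).reshape(2, 2)
J_norm = jacobian(norm_fn, w).reshape(1, 2)
J_logprobs = jacobian(logprobs_fn, w).reshape(2, 2)

# Per output, the partial derivatives do subtract as expected
print(torch.allclose(J_logprobs, J_logits - J_norm))

# Summed over outputs (what .sum().backward() computes), norm counts once per output
summed = J_logits.sum(dim=0) - out_dim_count * J_norm.sum(dim=0) if (out_dim_count := 2) else None
print(torch.allclose(J_logprobs.sum(dim=0), summed, atol=1e-6))
```

The second check only passes because the `norm` row is counted twice, once per output it was broadcast to; subtracting the gradient of `norm.sum()` a single time is what produced the mismatch in the original question.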

Anyway, cheers!