Alright so this looks like such a simple issue that I’m afraid I’ve misunderstood some of the core workings of PyTorch/autograd.
My problem is straightforward: the gradient of a sum comes out as 0 when it shouldn't, and it doesn't match what I get when I backpropagate the summands separately and combine their gradients.
I've concocted a simplified example to illustrate. We're in the probability domain, or rather the logprob domain, so this is the standard way of modelling probabilities with a softmax on top of a linear layer. Well, not quite a softmax, because we stay in logprobs. There are two output classes.
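To make the "not quite a softmax" part concrete: by logprobs I mean the logits minus their logsumexp, which, unless I'm mixing things up, is exactly what `F.log_softmax` computes. A tiny standalone check (separate from the failing example, values are arbitrary):

```python
import torch
import torch.nn.functional as F

# Illustration only: logprobs = logits - logsumexp(logits), i.e. log_softmax.
logits = torch.tensor([[0.2, -1.3]])
logprobs = logits - torch.logsumexp(logits, dim=1, keepdim=True)
print(torch.allclose(logprobs, F.log_softmax(logits, dim=1)))  # True
```

Now, the setup for the actual example: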
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

feat_dim, out_dim = 1, 2
encoder = nn.Linear(feat_dim, out_dim, bias=False)
nn.init.zeros_(encoder.weight.data)

# Fabricate some one-hot encoded data of batch size 1, good enough
batch_size = 1
input = torch.eye(feat_dim)[:batch_size]
```
As I said, I'm looking at the gradient of a sum, or rather of a difference, really. Below I show the gradients of the linear layer's weights after backpropagating first the difference, and then each of the two summands separately.
```python
logits = encoder.forward(input)
norm = torch.logsumexp(logits, dim=1, keepdim=True)
(logits - norm).sum().backward()
print(encoder.weight.grad)
nn.init.zeros_(encoder.weight.grad.data)

logits = encoder.forward(input)
logits.sum().backward()
print(encoder.weight.grad)
nn.init.zeros_(encoder.weight.grad.data)

logits = encoder.forward(input)
norm = torch.logsumexp(logits, dim=1, keepdim=True)
norm.sum().backward()
print(encoder.weight.grad)
```
This gives (annotated):
```
# grad of logits - norm
tensor([[0.],
        [0.]])
# grad of logits
tensor([[1.],
        [1.]])
# grad of norm
tensor([[0.5000],
        [0.5000]])
```
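The same thing reproduces without the manual `.grad` zeroing if `torch.autograd.grad` is used instead of `.backward()`, so I don't think the bookkeeping around `.grad` is to blame. A sketch, reusing `encoder` and `input` from the setup above:

```python
# Sketch: the same three gradients via torch.autograd.grad, which returns the
# gradient directly instead of accumulating it into encoder.weight.grad.
def weight_grad(make_loss):
    logits = encoder.forward(input)
    norm = torch.logsumexp(logits, dim=1, keepdim=True)
    return torch.autograd.grad(make_loss(logits, norm), encoder.weight)[0]

print(weight_grad(lambda logits, norm: (logits - norm).sum()))  # zeros again
print(weight_grad(lambda logits, norm: logits.sum()))           # ones
print(weight_grad(lambda logits, norm: norm.sum()))             # 0.5s
```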
The gradients for `norm` are exactly what I expected. I expected the gradient for the logprobs (`logits - norm`) to be the difference of the gradients of `logits` and `norm`, i.e. `tensor([[0.5], [0.5]])`, but that's not the case. Instead it's all zeros, for no conceivable reason.
I must be missing something incredibly simple. Maybe the gradient of `norm` gets backpropagated multiple times, but I don't see how. Could anyone shed some light on the issue?
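In case it helps with diagnosing, I suppose the gradient that actually arrives at `norm` during the backward pass of the difference could be inspected with a tensor hook, something along these lines (again reusing the setup above):

```python
# Sketch: peek at the gradient flowing into `norm` when backpropagating the difference.
logits = encoder.forward(input)
norm = torch.logsumexp(logits, dim=1, keepdim=True)
norm.register_hook(lambda g: print("grad arriving at norm:", g))
(logits - norm).sum().backward()
```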