I’ve been trying to understand more about autograd and how the gradients are computed for the backward pass. Softmax is an interesting case because its gradient is not elementwise. For each set of features that softmax is applied to, you have to compute a Jacobian whose diagonal entries are s_i(1 - s_i) and whose off-diagonal entries are -s_i s_j for i != j, where s is the vector of softmax output probabilities. That is the single-vector case, but I thought it would generalize by flattening the higher-rank input tensor, computing the Jacobian, summing the gradients for each example (collapsing the columns), and then reshaping the result to the original input shape.

However, I’m finding out that this is not how autograd computes it. Here is what I’m trying to understand: is there a way to explain the backward call for softmax without cross-entropy loss?

```python
import torch
import torch.nn.functional as f

def main():
    # trying to represent a batch with 5 samples of 4 features
    x = torch.arange(20).reshape(5, 4)
    x = x.float().requires_grad_()

    # Jacobian computation
    def softmax_grad(probs):
        tensor = probs.clone().detach()
        flat = torch.flatten(tensor)
        diagonal = torch.diagflat(flat)
        off_diagonal = torch.outer(flat, flat)
        return diagonal - off_diagonal

    probs = f.softmax(x, dim=-1)
    grad = torch.ones_like(probs)
    probs.backward(grad)
    jacobian = softmax_grad(probs)
    x_grad = torch.sum(jacobian, dim=-1, keepdim=True).reshape(x.size()) * grad
    print(f"What I expected:\n{x_grad}\n")
    print(f"What autograd computed:\n{x.grad}")

if __name__ == "__main__":
    main()
```

When you “just flatten” the input tensor, you lose track of where the rows
are (which individually are probability distributions across the features
and sum to one) and, in effect, have one long row (which no longer sums
to one). As a result of flattening the batch and features dimensions
together, you introduce spurious non-zero elements into your putative
jacobian.
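To see the spurious terms concretely, here is a small sketch in NumPy rather than PyTorch, so it runs without autograd. The helper names softmax, flattened_jacobian, and block_diagonal_jacobian are illustrative, not library functions. It builds the flattened matrix from the question and the block-diagonal jacobian that softmax (dim = -1) actually has, then compares them.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability; the result is unchanged.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def flattened_jacobian(probs):
    # The approach from the question: treat the whole batch as one long vector.
    flat = probs.ravel()
    return np.diag(flat) - np.outer(flat, flat)

def block_diagonal_jacobian(probs):
    # The actual [B*C, B*C] jacobian: one [C, C] block per row, zeros elsewhere,
    # because softmax over the last axis never mixes different rows.
    batch, classes = probs.shape
    jac = np.zeros((batch * classes, batch * classes))
    for b in range(batch):
        s = probs[b]
        rows = slice(b * classes, (b + 1) * classes)
        jac[rows, rows] = np.diag(s) - np.outer(s, s)
    return jac

x = np.arange(20.0).reshape(5, 4)
probs = softmax(x)
jac_flat = flattened_jacobian(probs)
jac_true = block_diagonal_jacobian(probs)

# True inside the diagonal [4, 4] blocks, False in the cross-row blocks
block_mask = np.kron(np.eye(5), np.ones((4, 4))).astype(bool)

print(np.allclose(jac_flat[block_mask], jac_true[block_mask]))  # same diagonal blocks
print(np.abs(jac_flat[~block_mask]).max())                      # spurious cross-row terms
print(jac_flat.sum(axis=0).reshape(5, 4))                       # the "expected" result
print(np.allclose(jac_true.sum(axis=0), 0.0))                   # zero, like autograd
```

Since each of the five rows sums to one, column j of the flattened matrix sums to s_j - 5 s_j = -4 s_j, so the "expected" result in the question is -4 times the probabilities (up to round-off) rather than zero.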

Just to be clear, you do expect this gradient to be zero. This is because probs.backward (torch.ones_like (probs)) is the same as probs.sum().backward(), and because the output of softmax() sums
(row-wise) to one, you are taking the gradient of a constant. So that
gradient is zero.
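The same point can be checked without autograd. This is a minimal NumPy sketch, assuming central finite differences (via a hypothetical helper, numerical_gradient) are an acceptable stand-in for backward(). The scalar probs.sum() stays equal to the number of rows no matter what x is, so its numerical gradient is zero up to round-off.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def numerical_gradient(f, x, eps=1e-6):
    # Central finite differences: perturb one element of x at a time.
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = eps
        grad[idx] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

rng = np.random.default_rng(2024)
x = rng.standard_normal((5, 4))

# probs.backward(torch.ones_like(probs)) corresponds to the scalar probs.sum()
total = lambda t: softmax(t).sum()

print(total(x))                  # about 5.0 (one per row), whatever x is
grad = numerical_gradient(total, x)
print(np.abs(grad).max())        # tiny: only finite-difference round-off remains
```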

The following script illustrates these two points:

```python
import torch
print (torch.__version__)
_ = torch.manual_seed (2024)

def softmax_grad(probs):
    tensor = probs.clone().detach()
    flat = torch.flatten(tensor)
    diagonal = torch.diagflat(flat)
    off_diagonal = torch.outer(flat, flat)
    return diagonal - off_diagonal

y = torch.arange (4.) # no batch dimension
probs_y = y.softmax (dim = -1)
print ('y = ...')
print (y)
print ('probs_y = ...')
print (probs_y)
jacyA = softmax_grad (probs_y) # works with no batch dimension
jacyB = torch.autograd.functional.jacobian (torch.nn.Softmax (dim = -1), y)
print ('jacyA = ...')
print (jacyA)
print ('torch.allclose (jacyA, jacyB) =', torch.allclose (jacyA, jacyB))

x = torch.arange (20).reshape (5, 4) # has batch dimension
x = x.float().requires_grad_()
probs_x = x.softmax (dim = -1)
print ('x = ...')
print (x)
print ('probs_x = ...')
print (probs_x)
jacxA = softmax_grad (probs_x) # incorrectly flattens batch and features dimensions together
jacxB = torch.autograd.functional.jacobian (torch.nn.Softmax (dim = -1), x)
print ('jacxA.shape =', jacxA.shape) # wrong shape for jacobian of 2d tensor
print ('jacxB.shape =', jacxB.shape)
jfA = jacxA.flatten() # reshaping doesn't help
jfB = jacxB.flatten()
print ('(jfA - jfB).abs().max() =', (jfA - jfB).abs().max())
print ('(jfA == 0).sum() =', (jfA == 0).sum()) # too many non-zero elements (but ...)
print ('(jfB == 0).sum() =', (jfB == 0).sum())
print ('torch.allclose (jfA[torch.where (jfB != 0)], jfB[torch.where (jfB != 0)]) =', torch.allclose (jfA[torch.where (jfB != 0)], jfB[torch.where (jfB != 0)]))

z = torch.randn (5, 4, requires_grad = True) # has batch dimension
probs_z = z.softmax (dim = -1)
print ('z = ...')
print (z)
print ('probs_z = ...')
print (probs_z)
print ('probs_z.sum (dim = -1) = ...')
print (probs_z.sum (dim = -1)) # probs sum to one, a constant, for each row in batch
probs_z.sum().backward() # sum over batch, as well as features, to get a scalar
gradzA = z.grad # gradient of sum of softmax is zero because sum is constant
jaczB = torch.autograd.functional.jacobian (torch.nn.Softmax (dim = -1), z)
gradzB = jaczB.sum (dim = (0, 1)) # sum partials over the output dims: derivative (gradient) of sum is sum of derivatives
print ('gradzA = ...')
print (gradzA)
print ('torch.allclose (gradzA, gradzB, atol = 1.e-7) =', torch.allclose (gradzA, gradzB, atol = 1.e-7))
```

Here is my interpretation: we have to compute the Jacobian matrix for each value in the input array that had softmax applied to it. So for our (5, 4) example, we’d have five 4 x 4 Jacobians with Softmax(dim=-1)? Does the same concept generalize to higher-rank tensors? I’m hesitant to make that claim, given that my last generalization led me to create this post. Also, I see you’re using autograd.functional to help with the Jacobian computation, but I’m interested in how this could be done without that module. More specifically, is there a vectorized (NumPy) way to compute the gradients (Jacobians) for the rows that softmax was applied to? And lastly, when computing the gradient for an input, is our goal essentially to compress the rows of each example’s Jacobian (to capture all the gradient information for that particular class)? I ask because that appears to be what you’re doing in your batched example with the (5, 4, 5, 4) Jacobian.

Yes, or more precisely, for each row that has softmax() applied to it.
If such a row has length 4, the jacobian will have shape [4, 4].

(You could compute elements of the entire jacobian that correspond to
elements of one row and elements of another, but these terms will be
zero because softmax() doesn’t mix the different rows together.)
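Both statements can be checked numerically with a short NumPy sketch. numerical_jacobian is a hypothetical helper based on central finite differences, and its output-shape-then-input-shape layout is chosen to mirror torch.autograd.functional.jacobian. The diagonal blocks should match diag(s) - s s^T, and the cross-row blocks should be exactly zero.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def numerical_jacobian(f, x, eps=1e-6):
    # jac[out_idx + in_idx] = d f(x)[out_idx] / d x[in_idx]:
    # output shape followed by input shape.
    out_shape = f(x).shape
    jac = np.zeros(out_shape + x.shape)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = eps
        jac[(...,) + idx] = (f(x + step) - f(x - step)) / (2 * eps)
    return jac

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4))          # two rows, four classes each
probs = softmax(x)
jac = numerical_jacobian(softmax, x)     # shape (2, 4, 2, 4)

for r in range(2):
    s = probs[r]
    # s_i (1 - s_i) on the diagonal, -s_i s_j off the diagonal
    analytic = np.diag(s) - np.outer(s, s)
    print(np.allclose(jac[r, :, r, :], analytic, atol=1e-6))

# Perturbing one row never changes the other row's probabilities
print(np.abs(jac[0, :, 1, :]).max(), np.abs(jac[1, :, 0, :]).max())
```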

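Yes, five [4, 4] jacobians, one per row.

Yes, the same holds for higher-rank tensors: there is one jacobian for each row that softmax() is applied to.

Yes (and you can use pytorch instead of numpy if you prefer).

(As for compressing the rows of each jacobian, I don’t follow what you’re asking here. What do you mean by “compress”?)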
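For the vectorized NumPy question, one possible version is sketched below, assuming softmax is taken over the last axis (the function name softmax_jacobians is illustrative). It uses broadcasting, J[..., i, j] = s_i (delta_ij - s_j), so it handles any number of leading "batch" dimensions without a Python loop.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softmax_jacobians(probs):
    # probs has shape (..., C); the result has shape (..., C, C),
    # one C x C jacobian per row: J[..., i, j] = s_i * (delta_ij - s_j).
    eye = np.eye(probs.shape[-1])
    return probs[..., :, None] * (eye - probs[..., None, :])

rng = np.random.default_rng(2024)
x = rng.standard_normal((3, 5, 4))       # two "batch" dimensions, 4 classes
probs = softmax(x)
jacs = softmax_jacobians(probs)
print(jacs.shape)                         # (3, 5, 4, 4)

# Independent cross-check with an explicit diagonal and an einsum outer product
outer = np.einsum('...i,...j->...ij', probs, probs)
diag = probs[..., :, None] * np.eye(4)
print(np.allclose(jacs, diag - outer))
```

The einsum form is only there as a cross-check; either works. The same indexing and broadcasting also work on pytorch tensors, with torch.eye in place of np.eye.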

Here is an example that shows how to tweak your computation to work
with batches and multi-dimensional “batches”:

```python
import torch
print (torch.__version__)
_ = torch.manual_seed (2024)

def softmax_grad(probs): # your original version
    tensor = probs.clone().detach()
    flat = torch.flatten(tensor)
    diagonal = torch.diagflat(flat)
    off_diagonal = torch.outer(flat, flat)
    return diagonal - off_diagonal

def softmax_jacB (probs): # your version adapted for "batch" dimension(s)
    probs = probs.clone().detach()
    shape_out = list (probs.size()) + [probs.size (-1)] # shape of tensor of jacobians
    if probs.ndim > 1:
        probs = probs.flatten (end_dim = -2) # flatten all but the last dimension into a single "batch" dimension
    diagonal = probs.diag_embed() # form batch of diagonal matrices
    off_diagonal = probs.unsqueeze (-1) @ probs.unsqueeze (-2) # form batch of outer products
    batch_jac = diagonal - off_diagonal # batch of jacobians
    return batch_jac.reshape (shape_out) # unflatten "batch" dimension(s)

y = torch.randn (4) # no batch dimension
probs_y = y.softmax (dim = -1)
print ('y = ...')
print (y)
print ('probs_y = ...')
print (probs_y)
jacyA = softmax_grad (probs_y) # works only with no batch dimension
jacyB = softmax_jacB (probs_y) # works with or without "batch" dimensions
print ('jacyA = ...')
print (jacyA)
print ('torch.allclose (jacyA, jacyB) =', torch.allclose (jacyA, jacyB))

x = torch.randn (3, 5, 4) # has multiple "batch" dimensions
x = x.float().requires_grad_()
probs_x = x.softmax (dim = -1)
print ('x[0, 0, :] = ...') # just look at the [0, 0] element of the "batch"
print (x[0, 0, :])
print ('probs_x[0, 0, :] = ...')
print (probs_x[0, 0, :])
jacxA = torch.autograd.functional.jacobian (torch.nn.Softmax (dim = -1), x)
jacxB = softmax_jacB (probs_x) # works with multiple "batch" dimensions
print ('jacxA.shape =', jacxA.shape) # shape includes block-off-diagonal zero terms
print ('jacxB.shape =', jacxB.shape) # no block-off-diagonal terms
jacxA = jacxA.diagonal (dim1 = 0, dim2 = 3).diagonal (dim1 = 0, dim2 = 2).permute (2, 3, 0, 1)
print ('jacxA.shape =', jacxA.shape) # block-off-diagonal terms removed
print ('torch.allclose (jacxA, jacxB) = ...')
print (torch.allclose (jacxA, jacxB))
```

First and foremost, I really appreciate the level of detail you’ve provided with the examples. They’ve really helped me understand what’s going on with the derivation. I don’t have any further questions about computing the Jacobian for softmax applied to higher-rank tensors. However, to answer your question, I was wondering how we get the gradient for the original input tensor from the Jacobian (this is what I meant by “compress”). In particular, to get the gradient of x with shape (3, 5, 4), would we take its Jacobian and sum over the last dimension, i.e. jacxB.sum(dim=-1, keepdims=True), to accumulate the gradients of the four classes for each of the five rows in each of the three examples in x?

Let me clarify some terminology: The jacobian is the matrix of (so-called
mixed) partial derivatives of a vector-valued function with respect to its
vector-valued argument. The gradient is the vector of partial derivatives of
a scalar-valued function with respect to its vector-valued argument.

And to use the language carefully, we don’t “get the gradient of x,” we get
the gradient (of a scalar-valued function of x) with respect to x.

So in order to “get a gradient,” you have to specify the scalar whose
gradient you want to get.

If you “compress” the jacobian with jacxB.sum (dim = -1), you will
get a batch of gradients (because jacxB is explicitly a batch of jacobians,
rather than just a single jacobian). But what (batch of) scalar(s) are you
getting the gradient(s) of?

The sum of the derivatives is the derivative of the sum.

So if you think of probs_x = x.softmax (dim = -1) as a batch of sets of
four individual probabilities, then you can think of probs_x.sum (dim = -1)
as a batch of scalars, each of which is the sum of four probabilities.

The probabilities computed by softmax() sum to one, a constant, so in
this case, after you sum over the rows of jacxB you will get a batch of zeros
(up to round-off error), because the sum gives you a batch of gradients of a
batch of scalars that are all the constant one.
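Here is a NumPy sketch of "the sum of the derivatives is the derivative of the sum," under the same assumptions as the earlier sketches (illustrative helper names, finite differences as the reference). Summing each jacobian over its output index gives the gradient of the summed outputs. That gradient is zero for the full sum, which is constant, but non-zero for a partial sum such as s_0 + s_1, which is not.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softmax_jacobians(probs):
    # (..., C) -> (..., C, C) with J[..., i, j] = d s_i / d x_j
    eye = np.eye(probs.shape[-1])
    return probs[..., :, None] * (eye - probs[..., None, :])

def numerical_row_gradient(f, x, eps=1e-6):
    # f maps each row of x to one scalar. Perturbing column j in every row at
    # once is valid here only because each row's output depends on that row alone.
    grad = np.zeros_like(x)
    for j in range(x.shape[-1]):
        step = np.zeros_like(x)
        step[..., j] = eps
        grad[..., j] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 5, 4))
jacs = softmax_jacobians(softmax(x))

full_sum = lambda t: softmax(t).sum(axis=-1)             # constant 1 for every row
partial_sum = lambda t: softmax(t)[..., :2].sum(axis=-1)  # s_0 + s_1, not constant

grad_full = jacs.sum(axis=-2)                  # sum over all four outputs i
grad_partial = jacs[..., :2, :].sum(axis=-2)   # sum over outputs i = 0, 1 only

print(np.abs(grad_full).max())   # round-off only: gradient of a constant
print(np.allclose(grad_full, numerical_row_gradient(full_sum, x), atol=1e-6))
print(np.allclose(grad_partial, numerical_row_gradient(partial_sum, x), atol=1e-6))
```

Because each softmax jacobian is symmetric, summing over the last axis instead (as jacxB.sum (dim = -1) does) gives the same numbers.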

To help understand the difference between jacobians and gradients and
how autograd turns conceptual jacobians into concrete gradients as it
backpropagates up through the layers, take a look at the documentation
for autograd.grad(), its grad_outputs argument, and pytorch’s concept
of vector-jacobian product, as in this quote from the above-referenced
documentation:

grad_outputs (sequence of Tensor) – The “vector” in the vector-Jacobian product. Usually gradients w.r.t. each output.
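In the vector-Jacobian-product view, the jacobian never has to be built: given grad_outputs g for a row, the gradient with respect to x is g^T J, which for softmax simplifies to s * (g - sum_k g_k s_k). The NumPy sketch below checks that closed form against an explicit jacobian contraction. softmax_vjp is an illustrative name, and the claim that autograd's softmax backward computes something equivalent is an assumption here, not something quoted from pytorch's source.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def softmax_vjp(probs, grad_out):
    # Vector-jacobian product for softmax over the last axis, row by row:
    # (g^T J)_j = sum_i g_i s_i (delta_ij - s_j) = s_j * (g_j - sum_i g_i s_i)
    return probs * (grad_out - (grad_out * probs).sum(axis=-1, keepdims=True))

rng = np.random.default_rng(1)
x = rng.standard_normal((5, 4))
probs = softmax(x)
g = rng.standard_normal((5, 4))   # plays the role of grad_outputs, e.g. dL/dprobs

# Explicit route: build each row's jacobian, then contract g with the output index i.
jacs = probs[..., :, None] * (np.eye(4) - probs[..., None, :])
vjp_explicit = np.einsum('bi,bij->bj', g, jacs)

vjp_closed = softmax_vjp(probs, g)
print(np.allclose(vjp_explicit, vjp_closed))

# g = ones, as in probs.backward (torch.ones_like (probs)), gives zero in every row
print(np.abs(softmax_vjp(probs, np.ones_like(probs))).max())
```

With g = ones, the factor g - sum_k g_k s_k is 1 - 1 = 0 in every row, which is another way to see why the original backward call produced zeros.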

After doing some digging and further reading, I think I’m starting to get your point. The Jacobian is a matrix where each row holds the partial derivatives of a particular output w.r.t. each input of the original vector-valued function. The gradient can be seen as just a vector containing the partial derivatives of a scalar-valued multivariate function (most often a loss function in ML) w.r.t. every variable passed to it as a parameter.

In our particular example with the higher-rank tensor, when we compute jacxB.sum(dim=-1), we obtain the gradient of each scalar sum produced by probs_x.sum(dim=-1), that is, five sums for each of the three examples in the batch. Collectively, these individual gradients are the gradients we would pass back through the model during backpropagation (pretending this is a neural network). Am I in the ballpark here?

Now, the reason we’re getting (practically) all zeros in the gradients of probs_x.sum(dim=-1) w.r.t. x (which is equivalent to summing each Jacobian in the batch along its outputs) is that we have a function that always produces the constant 1 (i.e., f(x) = softmax(x, dim=-1).sum(dim=-1)), since softmax always sums to 1 for each row in the batch of probabilities. So summing each row of each Jacobian just gives zeros (up to round-off) for all the gradients of our function w.r.t. x; changing an input in any way won’t change the output of f(x).

This behavior can be seen more clearly when you compute a single row of a Jacobian and sum it: $s_i(1 - s_i) - \sum_{j \neq i} s_i s_j = s_i \left(1 - \sum_{j=1}^{n} s_j\right) = 0$, where i is one particular class, j runs over the other classes in the first sum, and n is the total number of classes.