Jacobian of Cross Entropy Loss

I’m trying to find the Jacobian of `nn.CrossEntropyLoss`, where my inputs are some network output and a target in the form of class indices. I tried a simple test case as follows:

``````
import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()

input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)

jacobian = torch.autograd.functional.jacobian(loss, (input, target))
``````

This throws `RuntimeError: only Tensors of floating point dtype can require gradients`.

I tried converting `target` to a float tensor, but that doesn’t work with the cross entropy, which wants `target` to be either int or long. Is there any way around this? I’m new to PyTorch, so I apologize if I’m thinking about this completely wrong. Overall I want to be able to do forward-mode AD on the loss so that I can compute a directional derivative / Jacobian-vector product in the direction of some vector v, or in this case (since cross entropy outputs a scalar) the magnitude of the projection of the gradient along v.

Hi Jean!

The Jacobian matrix is the matrix of partial derivatives of a number of
results (the vector-valued result) of a function with respect to a number
of inputs (a vector-valued input) to that function.

Could you specify which partial derivatives you want to compute, that is
the derivatives of which specific results with respect to which specific
inputs?

Your code corresponds to the most common use case of `CrossEntropyLoss`,
namely, where the `target` is a batch of integer class labels. It doesn’t make
sense to differentiate something with respect to an integer, because it doesn’t
make sense to vary an integer value an infinitesimal amount away from being
an integer.

In your Jacobian computation, are you intending to differentiate something
with respect to your integer `target`? If so, what is that supposed to mean,
conceptually and mathematically?

Best.

K. Frank

I was just looking to find the gradient of the loss with respect to `input`, as would normally be done with backprop using `loss.backward()`. My hope was to do it via forward mode (for various reasons), but since the loss takes both `input` and `target`, I wasn’t sure if there was an easy way to split it up or get around this error message.

Hi Jean!

First a general answer, and then a work-around for your use case:

For context:

The problem is that you want to apply `jacobian()` to a function one of
whose arguments (`target`) is an integer tensor – not differentiable – hence
`jacobian()` complains. But you only want gradients of the floating-point,
differentiable argument (`input`), which is, in principle, okay.

As a general approach, define a function-object class. When you instantiate
it, pass it `target`, and have it store `target` as a property. Have its `__call__`
method only take `input` (and `self`) as an argument. Now `jacobian()` will
only try to calculate gradients of the floating-point `input`.

Thus:

``````
import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

class CrossEntropyWrapper:
    def __init__ (self, target):
        self.target = target
    def __call__ (self, input):
        return  torch.nn.functional.cross_entropy (input, self.target)

input = torch.randn (3, 5, requires_grad = True)
target = torch.empty (3, dtype = torch.long).random_ (5)

wrapper = CrossEntropyWrapper (target)
jbA = torch.autograd.functional.jacobian (wrapper, input)

print (jbA)
``````

resulting in:

``````
1.11.0
tensor([[-0.2766,  0.0651,  0.0589,  0.1144,  0.0382],
        [-0.3239,  0.0023,  0.1199,  0.1836,  0.0181],
        [ 0.0190,  0.0239,  0.0961,  0.1055, -0.2445]])
``````
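(Editorial aside, not from the original reply: the same binding of `target` can be done with a plain closure instead of a function-object class; a minimal sketch of the equivalent approach.)

```python
import torch

_ = torch.manual_seed (2022)

input = torch.randn (3, 5, requires_grad = True)
target = torch.empty (3, dtype = torch.long).random_ (5)

# bind the integer target in a closure so that jacobian() only
# differentiates with respect to the floating-point input
f = lambda inp: torch.nn.functional.cross_entropy (inp, target)

jb = torch.autograd.functional.jacobian (f, input)
print (jb.shape)   # torch.Size([3, 5])
```

Because the loss is a scalar, the "Jacobian" here is just the gradient of the loss with respect to each entry of `input`, hence the `(3, 5)` shape.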

For your specific use case – because (as of a few versions ago)
`CrossEntropyLoss` can also take floating-point “soft” targets (of
a different shape) – you can let `jacobian()` calculate the gradients
of both `input` and a floating-point version of `target`, and you can
just ignore the gradients of `target`.

Continuing:

``````
loss = torch.nn.CrossEntropyLoss()

target_onehot_float = torch.nn.functional.one_hot (target, 5).float()
jbB = torch.autograd.functional.jacobian (loss, (input, target_onehot_float))

print (jbB)
``````

with the result:

``````
(tensor([[-0.2766,  0.0651,  0.0589,  0.1144,  0.0382],
        [-0.3239,  0.0023,  0.1199,  0.1836,  0.0181],
        [ 0.0190,  0.0239,  0.0961,  0.1055, -0.2445]]), tensor([[0.5905, 0.5442, 0.5775, 0.3565, 0.7225],
        [1.1887, 1.6564, 0.3409, 0.1988, 0.9707],
        [0.9557, 0.8783, 0.4147, 0.3834, 0.4406]]))
``````

Best.

K. Frank
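(Editorial sketch, not part of the reply above: the original question asked for forward-mode AD. With the integer `target` bound inside a wrapper function, `torch.autograd.functional.jvp()` returns the directional derivative of the loss along a direction `v` directly; the direction `v` below is illustrative.)

```python
import torch

_ = torch.manual_seed (2022)

input = torch.randn (3, 5)
target = torch.empty (3, dtype = torch.long).random_ (5)
v = torch.randn (3, 5)   # direction for the directional derivative

def f (inp):
    # target is bound here, so jvp() only sees the float input
    return torch.nn.functional.cross_entropy (inp, target)

# returns (value of the loss, directional derivative of the loss along v)
loss_val, dirderiv = torch.autograd.functional.jvp (f, input, v)
print (loss_val, dirderiv)
```

Since the loss is a scalar, `dirderiv` equals the dot product of the full gradient with `v`, i.e. the magnitude of the projection asked about in the question (up to the norm of `v`).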