# How to not lose the gradient on this operation?

Hey,

I have a model that outputs batches with size B of “A-dimensional” softmax probabilities for C classes. For instance, the output `w` could look like this, for B=2, A=4 and C=3:

```python
# B x A x C
w = torch.softmax(torch.rand([2, 4, 3]), dim=2)
w

# standard A-dimensional class probabilities, shape: B x A x C
>> tensor([[[0.2901, 0.4441, 0.2658],
            [0.4681, 0.2511, 0.2808],
            [0.2814, 0.3119, 0.4067],
            [0.4007, 0.3358, 0.2635]],

           [[0.2750, 0.4774, 0.2476],
            [0.3758, 0.3680, 0.2562],
            [0.3018, 0.3007, 0.3976],
            [0.3610, 0.3699, 0.2691]]])
```

where the probabilities sum to 1 along the class dimension.

In my loss function, I need to construct the labels / targets `y` as a function of the input `w`, i.e. `y = y(w)`, before feeding both into `nn.CrossEntropyLoss()(w.transpose(1,2), y)`. To do that, I need to convert `w` to a representation that contains the class numbers instead of the softmax probabilities. I do it like this:

```python
# first step
wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))  # B x A x C
wa

# intermediate output
# (indicates the location of the maximum probabilities)
>> tensor([[[0, 1, 0],
            [1, 0, 0],
            [0, 0, 1],
            [1, 0, 0]],

           [[0, 1, 0],
            [1, 0, 0],
            [0, 0, 1],
            [0, 1, 0]]])

# second step
# (converts the intermediate output vectors to the predicted class labels)
wp = torch.max(wa, dim=2)[1]
wp

# class number representation, shape: B x A (with dimension C dropped)
>> tensor([[1, 0, 2, 0],
           [1, 0, 2, 1]])
```
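(As a side note, not part of the original question: the two steps above can be collapsed into a single `torch.argmax` call, which makes it explicit that the whole conversion is an argmax. A minimal sketch checking the equivalence:)

```python
import torch

torch.manual_seed(0)
w = torch.softmax(torch.rand(2, 4, 3), dim=2)  # B x A x C

# two-step version from the post
wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))
wp_two_step = torch.max(wa, dim=2)[1]

# single-step equivalent
wp = torch.argmax(w, dim=2)  # B x A

assert torch.equal(wp, wp_two_step)
```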

I would like my final labels, which are constructed by further processing `wp`, to have a gradient with respect to `w`. This will be the case if `wp` has a gradient w.r.t. `w`. However, the second operation drops the gradient, because it asks for the max indices (tuple element `[1]`) of `wa`, which are not differentiable, I think. In other words, the second operation is an argmax, which does not have a gradient, as far as I can tell.

Is there any way to rewrite the operations (i.e. some smart way to circumvent the argmax) such that the gradient of `wp` w.r.t. `w` is preserved? I have not found one yet.

Thanks!

Best, JZ

Hi Jay!

The short story is that the gradient you are asking for isn’t useful because
it is zero almost everywhere.

Some more detail:

The problem is that your labels, `wp`, are integers (and would have integer
values even if you were to convert them to floating-point numbers). It
doesn’t really make sense to differentiate integers.

Consider `wp[0, 0]`. In the example you give it starts out as `1`. As you
vary the values in `w`, `wp[0, 0]` will stay `1` for a while, and over this
range the gradient will be zero (as `wp[0, 0]` is not changing).

Then, at some point as you vary `w`, `wp[0, 0]` will “pop” to another
value (say to `2`). Right at this point the gradient will be infinity (or,
if you prefer, undefined).

That is, for almost all values of `w`, the gradient will be zero and will
otherwise be infinite (or undefined).

(Because gradients of things like `argmax()` aren’t useful, pytorch
generally doesn’t compute them for you.)
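(A quick numerical illustration of this piecewise-constant behaviour, not from the original thread: a small perturbation of `w` leaves the integer labels unchanged, which is exactly the "zero gradient almost everywhere" situation.)

```python
import torch

torch.manual_seed(0)
w = torch.softmax(torch.rand(2, 4, 3), dim=2)  # B x A x C
wp = torch.argmax(w, dim=2)                    # integer labels, B x A

# nudge w slightly; the argmax labels do not move, so the
# "derivative" of wp with respect to w is zero at this point
w_perturbed = w + 1e-6 * torch.randn_like(w)
assert torch.equal(torch.argmax(w_perturbed, dim=2), wp)
```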

As an aside, it sounds like you want to use as a loss the cross entropy
of some floating-point numbers (`w`) with (integer) class labels (`wp`)
derived from those floating-point numbers. You will have useful gradients
of your loss with respect to `w` (but not `wp`).

I don’t understand your use case, so I can’t say whether gradients
directly with respect to `w` would make sense for you, but they would
be conceptually valid, and pytorch will compute them for you.
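(To illustrate this last point with a minimal sketch mirroring the shapes from the question: the integer labels `wp` act as a constant target as far as autograd is concerned, but the loss still has a gradient with respect to the floating-point values that produced `w`.)

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.rand(2, 4, 3, requires_grad=True)
w = torch.softmax(logits, dim=2)  # B x A x C

# integer labels derived from w; argmax is not differentiable,
# so wp carries no gradient information
wp = torch.argmax(w, dim=2)  # B x A

# CrossEntropyLoss wants the class dimension second: B x C x A
loss = nn.CrossEntropyLoss()(w.transpose(1, 2), wp)
loss.backward()

# gradient flows through w (and logits), not through wp
assert logits.grad is not None
```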

Best.

K. Frank


Hello K. Frank,