How to not lose the gradient on this operation?


I have a model that outputs batches with size B of “A-dimensional” softmax probabilities for C classes. For instance, the output w could look like this, for B=2, A=4 and C=3:

import torch

# B x A x C
w = torch.softmax(torch.rand([2, 4, 3]), dim=2)

# standard A-dimensional class probabilities, shape: B x A x C
>> tensor([[[0.2901, 0.4441, 0.2658],
            [0.4681, 0.2511, 0.2808],
            [0.2814, 0.3119, 0.4067],
            [0.4007, 0.3358, 0.2635]],
           [[0.2750, 0.4774, 0.2476],
            [0.3758, 0.3680, 0.2562],
            [0.3018, 0.3007, 0.3976],
            [0.3610, 0.3699, 0.2691]]])

where the probabilities sum to 1 along the class dimension.
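A quick sanity check of that property, as a minimal sketch. The fixed seed is an assumption added for reproducibility, so the numbers will not match the printout above:

```python
import torch

torch.manual_seed(0)  # arbitrary seed; values differ from the example printout
w = torch.softmax(torch.rand(2, 4, 3), dim=2)  # B x A x C

# Summing over the class dimension (dim=2) leaves a B x A tensor of ones
row_sums = w.sum(dim=2)
print(row_sums.shape)                                 # torch.Size([2, 4])
print(torch.allclose(row_sums, torch.ones(2, 4)))     # True
```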

In my loss function, I need to construct the labels / targets y as a function of the input w, i.e. y = y(w), before feeding both into nn.CrossEntropyLoss()(w.transpose(1,2), y). To do that, I need to convert w to a representation that shows the class numbers instead of the softmax probabilities. I do it like this:
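For reference, a sketch of the shape convention behind the transpose(1,2). It assumes raw logits are available (the names `logits` and `y` here are hypothetical), because nn.CrossEntropyLoss applies log_softmax internally and is normally fed unnormalized scores rather than softmax probabilities. For a per-position loss it expects input of shape (N, C, d1) and target of shape (N, d1):

```python
import torch
import torch.nn as nn

B, A, C = 2, 4, 3
torch.manual_seed(0)

# Hypothetical raw model outputs (logits), shape B x A x C
logits = torch.randn(B, A, C)
# Hypothetical integer targets, shape B x A, values in [0, C)
y = torch.randint(0, C, (B, A))

# Move the class dimension to dim 1: (B, A, C) -> (B, C, A)
loss = nn.CrossEntropyLoss()(logits.transpose(1, 2), y)
print(loss.shape)  # torch.Size([]): a scalar with the default reduction='mean'

# Equivalent formulation: flatten B and A into a single batch dimension
flat_loss = nn.CrossEntropyLoss()(logits.reshape(-1, C), y.reshape(-1))
print(torch.allclose(loss, flat_loss))  # True
```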

# first step
wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))  # B x A x C

# intermediate output 
# (indicates the location of the maximum probabilities)
>> tensor([[[0, 1, 0],
            [1, 0, 0],
            [0, 0, 1],
            [1, 0, 0]],
           [[0, 1, 0],
            [1, 0, 0],
            [0, 0, 1],
            [0, 1, 0]]])
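Note that this first step already produces a tensor without a gradient: the == comparison returns a bool tensor, and multiplying by 1 turns it into int64, neither of which autograd tracks. A minimal check, assuming a hypothetical `logits` tensor as the model output:

```python
import torch

# Hypothetical model output that autograd tracks
logits = torch.randn(2, 4, 3, requires_grad=True)
w = torch.softmax(logits, dim=2)
print(w.requires_grad)   # True

# First step from the post
wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))
print(wa.dtype)          # torch.int64
print(wa.requires_grad)  # False
print(wa.grad_fn)        # None
```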

# second step 
# (converts the intermediate output vectors to the predicted class labels)
wp = torch.max(wa, dim=2)[1]

# class number representation, shape: B x A (with dimension C dropped)
>> tensor([[1, 0, 2, 0],
           [1, 0, 2, 1]])
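The two steps together amount to a single argmax over the class dimension. A sketch using the probabilities printed above, assuming no ties within a row:

```python
import torch

# The softmax probabilities printed in the post, shape B x A x C
w = torch.tensor([[[0.2901, 0.4441, 0.2658],
                   [0.4681, 0.2511, 0.2808],
                   [0.2814, 0.3119, 0.4067],
                   [0.4007, 0.3358, 0.2635]],
                  [[0.2750, 0.4774, 0.2476],
                   [0.3758, 0.3680, 0.2562],
                   [0.3018, 0.3007, 0.3976],
                   [0.3610, 0.3699, 0.2691]]])

# Two-step version from the post
wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))
wp = torch.max(wa, dim=2)[1]

# One-step equivalent
wp_direct = torch.argmax(w, dim=2)

print(wp_direct.tolist())           # [[1, 0, 2, 0], [1, 0, 2, 1]]
print(torch.equal(wp, wp_direct))   # True
```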

I would like the final labels, which I construct by further processing wp, to have a gradient with respect to w. This will be the case if wp has a gradient wrt w. However, the second operation drops the gradient, because it asks for the max indices (tuple element [1]) of wa, which I think are not differentiable. In other words, the second operation is an argmax, which, as far as I know, does not have a gradient.
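The intuition about torch.max is right: of the (values, indices) pair it returns, only the values are differentiable. A sketch (the seed and the leaf-tensor setup are assumptions for illustration) showing that backpropagating through the max values puts a 1 at each row's maximum and 0 elsewhere, the same one-hot pattern as wa:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Make w a leaf tensor so autograd stores w.grad directly
w = torch.softmax(torch.rand(2, 4, 3), dim=2).requires_grad_()

values, indices = torch.max(w, dim=2)

print(values.requires_grad)   # True: the max values are differentiable in w
print(indices.dtype)          # torch.int64
print(indices.requires_grad)  # False: the indices are not tracked by autograd

# d(sum of row maxima)/dw is 1 at each row's argmax and 0 elsewhere
values.sum().backward()
print(torch.equal(w.grad, F.one_hot(indices, num_classes=3).to(w.dtype)))  # True
```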

Is there any way to rewrite the operations (i.e. some smart way to circumvent the argmax operation) such that the gradient of wp wrt w is somehow preserved? I haven't found one yet.


Best, JZ

Hi Jay!

The short story is that the gradient you are asking for isn’t useful because
it is zero almost everywhere.

Some more detail:

The problem is that your labels, wp, are integers (and would have integer
values even if you were to convert them to floating-point numbers). It
doesn’t really make sense to differentiate integers.
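PyTorch enforces this directly: an integer tensor cannot be marked as requiring gradients. Converting it to floating point is allowed, but the values stay integer-valued. A minimal sketch using the labels from the post:

```python
import torch

wp = torch.tensor([[1, 0, 2, 0], [1, 0, 2, 1]])  # class labels from the post, int64

raised = False
try:
    wp.requires_grad_()  # only floating-point (and complex) tensors may require grad
except RuntimeError as err:
    raised = True
    print(err)
print(raised)  # True

# A floating-point copy can require grad, but its values are still 0., 1., 2.
wp_float = wp.float()
print(torch.equal(wp_float, wp_float.round()))  # True
```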

Consider wp[0, 0]. In the example you give it starts out as 1. As you
vary the values in w, wp[0, 0] will stay 1 for a while, and over this
range the gradient will be zero (as wp[0, 0] is not changing).

Then, at some point as you vary w, wp[0, 0] will “pop” to another
value (say to 2). Right at this point the gradient will be infinity (or,
if you prefer, undefined).

That is, for almost all values of w, the gradient will be zero and will
otherwise be infinite (or undefined).
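A small numerical illustration of this, using the first row of w from the question (the helper `label` is hypothetical). A finite-difference estimate is exactly zero, and the label only changes by jumping once w[0, 0, 0] passes w[0, 0, 1]. Renormalizing the row would not change the result, since dividing by a positive sum preserves the ordering:

```python
import torch

# First row of w from the post; its argmax (wp[0, 0]) is class 1
row = torch.tensor([0.2901, 0.4441, 0.2658])

def label(r: torch.Tensor) -> int:
    """Hypothetical helper: the predicted class for one row of probabilities."""
    return int(torch.argmax(r))

# Finite-difference estimate of d(label)/d(row[0]) for a small bump
eps = 1e-3
bumped = row.clone()
bumped[0] += eps
fd = (label(bumped) - label(row)) / eps
print(fd)  # 0.0: the label does not move

# Sweep row[0] past row[1] (0.4441): the label "pops" from 1 to 0
labels = []
for x in [0.43, 0.44, 0.445, 0.45]:
    r = row.clone()
    r[0] = x
    labels.append(label(r))
print(labels)  # [1, 1, 0, 0]
```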

(Because gradients of things like argmax() aren’t useful, pytorch
generally doesn’t compute them for you.)

As an aside, it sounds like you want to use as a loss the cross entropy
of some floating-point numbers (w) with (integer) class labels (wp)
derived from those floating-point numbers. You will have useful gradients
of your loss with respect to w (but not wp).

I don’t understand your use case, so I can’t say whether gradients
directly with respect to w would make sense for you, but they would
be conceptually valid, and pytorch will compute them for you.
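A sketch of that situation, assuming the model's raw logits play the role of w (cross_entropy applies log_softmax itself). The integer labels y are derived from the logits by argmax, so autograd treats them as a constant target, while the loss still has a well-defined gradient with respect to the logits, namely (softmax - one_hot(y)) / number of positions for the default mean reduction:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical logits standing in for the model output, B x A x C
logits = torch.randn(2, 4, 3, requires_grad=True)

# Labels derived from the logits: integer-valued, not tracked by autograd
y = logits.argmax(dim=2)  # B x A

loss = F.cross_entropy(logits.transpose(1, 2), y)
loss.backward()

print(y.requires_grad)      # False
print(logits.grad.shape)    # torch.Size([2, 4, 3])

# Analytic gradient of the mean cross entropy with respect to the logits
expected = (torch.softmax(logits.detach(), dim=2)
            - F.one_hot(y, num_classes=3).to(logits.dtype)) / y.numel()
print(torch.allclose(logits.grad, expected, atol=1e-6))  # True
```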


K. Frank


Hello K. Frank,

thanks for your reply!
Right, I totally forgot that the integer nature of the labels rules out a gradient, on top of the argmax itself. I’ll try reformulating my loss as a continuous one and turning the problem into a regression instead of a classification. Then it should be possible to get the derivative of the labels wrt the input. Also, yes, the gradient of the loss wrt w is there, but for my application it’s crucial to have one wrt the labels as well.
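One common continuous stand-in for wp in such a regression reformulation is a "soft-argmax": the expected class index under the softmax probabilities. This technique is not from the thread, and the `logits` tensor below is a hypothetical model output; it is a sketch, not a drop-in replacement for the hard labels:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 4, 3, requires_grad=True)  # hypothetical model output, B x A x C
w = torch.softmax(logits, dim=2)

# Soft-argmax: expected class index under w, a floating-point B x A tensor
classes = torch.arange(w.shape[2], dtype=w.dtype)  # [0., 1., 2.]
wp_soft = (w * classes).sum(dim=2)

print(wp_soft.requires_grad)  # True

# Unlike argmax, its gradient with respect to the logits is generally nonzero
wp_soft.sum().backward()
print(logits.grad.abs().sum().item() > 0)  # True
```

Dividing the logits by a temperature below 1 pushes the soft-argmax toward the hard argmax, but in that limit the gradient again collapses toward zero almost everywhere, as described in the answer above.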

Best, JZ