Hey,

I have a model that outputs batches of size B of "A-dimensional" softmax probabilities over C classes. For instance, the output `w` could look like this for B=2, A=4 and C=3:

```
# B x A x C
w = torch.softmax(torch.rand([2,4,3]),dim=2)
w
# standard A-dimensional class probabilities, shape: B x A x C
>> tensor([[[0.2901, 0.4441, 0.2658],
            [0.4681, 0.2511, 0.2808],
            [0.2814, 0.3119, 0.4067],
            [0.4007, 0.3358, 0.2635]],

           [[0.2750, 0.4774, 0.2476],
            [0.3758, 0.3680, 0.2562],
            [0.3018, 0.3007, 0.3976],
            [0.3610, 0.3699, 0.2691]]])
```

where the probabilities sum to 1 along the class dimension (dim=2).
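This can be verified quickly (re-creating `w` as above; the seed is only there to make the example reproducible):

```
import torch

torch.manual_seed(0)
w = torch.softmax(torch.rand([2, 4, 3]), dim=2)  # B x A x C

# each of the B*A probability vectors sums to 1 along the class dimension
print(w.sum(dim=2))  # tensor of ones, shape B x A
assert torch.allclose(w.sum(dim=2), torch.ones(2, 4))
```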

In my loss function, I need to construct the labels / targets `y` as a function of the input `w`, i.e. `y = y(w)`, before feeding both into `nn.CrossEntropyLoss()(w.transpose(1,2), y)`. To do that, I need to convert `w` to a representation that holds the class numbers instead of the softmax probabilities. I do it like this:

```
# first step
wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))  # B x A x C
wa
# intermediate output
# (indicates the location of the maximum probabilities)
>> tensor([[[0, 1, 0],
            [1, 0, 0],
            [0, 0, 1],
            [1, 0, 0]],

           [[0, 1, 0],
            [1, 0, 0],
            [0, 0, 1],
            [0, 1, 0]]])
# second step
# (converts the intermediate output vectors to the predicted class labels)
wp = torch.max(wa, dim=2)[1]
wp
# class number representation, shape: B x A (with dimension C dropped)
>> tensor([[1, 0, 2, 0],
           [1, 0, 2, 1]])
```
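As an aside, the two steps appear to collapse to a single `argmax` over the class dimension; this doesn't change the gradient situation, only the bookkeeping (assuming no exact ties in `w`, which are virtually impossible with random floats):

```
import torch

torch.manual_seed(0)
w = torch.softmax(torch.rand([2, 4, 3]), dim=2)  # B x A x C

# two-step version from above
wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))
wp = torch.max(wa, dim=2)[1]

# single-step equivalent
wp_direct = torch.argmax(w, dim=2)  # B x A

assert torch.equal(wp, wp_direct)
```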

I would like my final labels, which are constructed by further processing `wp`, to have a gradient with respect to `w`. That will be the case if `wp` itself has a gradient w.r.t. `w`. However, the second step drops the gradient, because it takes the max indices (tuple element `[1]`) of `wa`, which are not differentiable. In other words, the second step is an argmax operation, which has no gradient, as far as I can tell.
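A minimal check confirms this: the index output is an integer tensor with no `grad_fn`, so autograd cannot flow through it (sketch using the same shapes as above):

```
import torch

torch.manual_seed(0)
x = torch.rand([2, 4, 3], requires_grad=True)
w = torch.softmax(x, dim=2)  # differentiable w.r.t. x

wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))
wp = torch.max(wa, dim=2)[1]

print(w.requires_grad)   # True
print(wp.dtype)          # torch.int64
print(wp.requires_grad)  # False: the index tensor is detached from the graph
```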

Is there any way to rewrite these operations (i.e. some smart way to circumvent the argmax) such that the gradient of `wp` w.r.t. `w` is preserved? I have not found one yet.

Thanks!

Best, JZ