Hey,

I have a model that outputs batches of size B of "A-dimensional" softmax probabilities over C classes. For instance, the output `w` could look like this for B=2, A=4 and C=3:

```
# B x A x C
w = torch.softmax(torch.rand([2,4,3]),dim=2)
w
# standard A-dimensional class probabilities, shape: B x A x C
>> tensor([[[0.2901, 0.4441, 0.2658],
            [0.4681, 0.2511, 0.2808],
            [0.2814, 0.3119, 0.4067],
            [0.4007, 0.3358, 0.2635]],

           [[0.2750, 0.4774, 0.2476],
            [0.3758, 0.3680, 0.2562],
            [0.3018, 0.3007, 0.3976],
            [0.3610, 0.3699, 0.2691]]])
```

where the probabilities sum to 1 along the class dimension (dim=2).
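This can be verified quickly (re-creating `w` as above; the seed is only there to make the example reproducible):

```
import torch

torch.manual_seed(0)
w = torch.softmax(torch.rand([2, 4, 3]), dim=2)  # B x A x C

# each of the B*A probability vectors sums to 1 along the class dimension
print(w.sum(dim=2))  # tensor of ones, shape B x A
assert torch.allclose(w.sum(dim=2), torch.ones(2, 4))
```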

In my loss function, I need to construct the labels / targets `y` as a function of the input `w`, i.e. `y = y(w)`, before feeding both into `nn.CrossEntropyLoss()(w.transpose(1,2), y)`. To do that, I need to convert `w` to a representation that holds the class numbers instead of the softmax probabilities. I do it like this:

```
# first step
wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))  # B x A x C
wa
# intermediate output
# (indicates the location of the maximum probabilities)
>> tensor([[[0, 1, 0],
            [1, 0, 0],
            [0, 0, 1],
            [1, 0, 0]],

           [[0, 1, 0],
            [1, 0, 0],
            [0, 0, 1],
            [0, 1, 0]]])
# second step
# (converts the intermediate output vectors to the predicted class labels)
wp = torch.max(wa, dim=2)[1]
wp
# class number representation, shape: B x A (with dimension C dropped)
>> tensor([[1, 0, 2, 0],
           [1, 0, 2, 1]])
```
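As an aside, the two steps appear to collapse to a single `argmax` over the class dimension; this doesn't change the gradient situation, only the bookkeeping (assuming no exact ties in `w`, which are virtually impossible with random floats):

```
import torch

torch.manual_seed(0)
w = torch.softmax(torch.rand([2, 4, 3]), dim=2)  # B x A x C

# two-step version from above
wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))
wp = torch.max(wa, dim=2)[1]

# single-step equivalent
wp_direct = torch.argmax(w, dim=2)  # B x A

assert torch.equal(wp, wp_direct)
```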

I would like my final labels, which are constructed by further processing `wp`, to have a gradient with respect to `w`. That will be the case if `wp` itself has a gradient w.r.t. `w`. However, the second step drops the gradient, because it takes the max indices (tuple element `[1]`) of `wa`, which are not differentiable. In other words, the second step is an argmax operation, which has no gradient, as far as I can tell.
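A minimal check confirms this: the index output is an integer tensor with no `grad_fn`, so autograd cannot flow through it (sketch using the same shapes as above):

```
import torch

torch.manual_seed(0)
x = torch.rand([2, 4, 3], requires_grad=True)
w = torch.softmax(x, dim=2)  # differentiable w.r.t. x

wa = 1 * (w == torch.max(w, dim=2)[0].unsqueeze(-1))
wp = torch.max(wa, dim=2)[1]

print(w.requires_grad)   # True
print(wp.dtype)          # torch.int64
print(wp.requires_grad)  # False: the index tensor is detached from the graph
```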

Is there any way to rewrite these operations (i.e. some smart way to circumvent the argmax) such that the gradient of `wp` w.r.t. `w` is preserved? I have not found one yet.

Thanks!

Best, JZ