torch.argmax() causes loss.backward() to fail

I am training a network with PyTorch. During training, I use 64 bins to predict the dimensions.
I can’t get it to work: my loss.backward() call fails!

            output_coor_x_ = output_coor_x_.squeeze()
            output_coor_y_ = output_coor_y_.squeeze()
            output_coor_z_ = output_coor_z_.squeeze()

            ####

            output_coor_ = torch.stack([torch.argmax(output_coor_x_, axis=0),
                                     torch.argmax(output_coor_y_, axis=0),
                                     torch.argmax(output_coor_z_, axis=0)], axis=2)
            output_coor_[output_coor_ == cfg.network.coor_bin] = 0
            output_coor_ = 2.0 * output_coor_.float() / (63.0) - 1.0      # [-1,1]
            output_coor_[:, :, 0] = output_coor_[:, :, 0] * abs(model_min_x)

If I compute the loss from the variables before ‘####’, loss.backward() works; but if I use the variable after ‘####’, loss.backward() fails with the following error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I am sure that the model’s parameters have requires_grad=True.

Maybe I shouldn’t use torch.argmax()?

Hi Kiruto!

In short, yes, don’t use argmax() in your loss function.

loss.backward() computes the gradient (derivative) of loss
with respect to your model parameters. (The optimizer then uses
a gradient-descent algorithm to adjust those parameters to make
the loss smaller.)

For this to work, your loss function has to be differentiable.
argmax() isn’t differentiable, so you can’t use it as part of
your loss.
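
As a minimal sketch of this point (the tensor shape and the names `scores` and `idx` are illustrative assumptions, not taken from the code in the question), the following shows that the result of argmax() is an integer tensor with no connection to the autograd graph, which reproduces the same RuntimeError:

```python
import torch

# Hypothetical stand-in for a network output: 64 bin scores at 5 positions.
scores = torch.randn(64, 5, requires_grad=True)

# argmax() returns integer bin indices; autograd does not track them.
idx = torch.argmax(scores, dim=0)
print(idx.dtype)          # integer dtype
print(idx.requires_grad)  # False: there is no path back to `scores`
print(idx.grad_fn)        # None

# Any loss built only from `idx` is detached from the graph as well.
loss = idx.float().mean()
try:
    loss.backward()
    backward_failed = False
except RuntimeError as err:
    backward_failed = True
    print(err)
```

Casting the indices with .float() does not help, because the link to the original scores was already lost at the argmax() step.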

I don’t follow what you are doing in your code, so I don’t know
specifically how you might fix it. But as a general comment,
you might investigate whether a differentiable approximation
to argmax() would still make sense in your loss function.

softmax() (which should more accurately be called softargmax())
can be understood to be a differentiable approximation to argmax()
(much in the same way that sigmoid() is a differentiable
approximation to a step function).

If you can reformulate your loss function to use softmax(), you
can make it differentiable and loss.backward() should work.
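
One possible reformulation, offered only as a sketch under assumptions, is a "soft argmax": take the softmax() over the bin dimension and compute the expected bin index. The helper `soft_argmax`, its `temperature` parameter, and the tensor shapes below are hypothetical and not part of the original code; the extra bin handled by cfg.network.coor_bin in the question is ignored here.

```python
import torch
import torch.nn.functional as F


def soft_argmax(logits, dim=0, temperature=1.0):
    """Differentiable stand-in for argmax: expected bin index under softmax.

    Hypothetical helper; a lower temperature moves the result closer to
    the hard argmax, at the cost of sharper gradients.
    """
    probs = F.softmax(logits / temperature, dim=dim)
    n_bins = logits.shape[dim]
    # Bin indices 0..n_bins-1, shaped to broadcast along `dim`.
    shape = [1] * logits.dim()
    shape[dim] = n_bins
    bins = torch.arange(n_bins, dtype=logits.dtype, device=logits.device).view(shape)
    return (probs * bins).sum(dim=dim)


torch.manual_seed(0)
# Assumed layout: 64 bins along dim 0 for an 8x8 map of x coordinates.
logits_x = torch.randn(64, 8, 8, requires_grad=True)

coor_x = soft_argmax(logits_x, dim=0)   # shape (8, 8), values in [0, 63]
coor_x = 2.0 * coor_x / 63.0 - 1.0      # rescale to [-1, 1], as in the question
loss = coor_x.abs().mean()              # placeholder loss for illustration
loss.backward()                         # gradients now reach logits_x
print(logits_x.grad.shape)
```

Whether the expected bin index is an acceptable substitute for the hard index depends on the task, so this is a starting point to evaluate rather than a drop-in fix.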

Good luck.

K. Frank

@KFrank @ptrblck Hello. I have a question about using argmax.
I am doing the following:

  1. A model is used as an encoder.
  2. The model predicts some outputs, which I detach from the graph and convert into a numpy array.
  3. In the subsequent calculations on this numpy array, I use argmax() to return the final result (for example, something like [[1,4,6,3]]; let’s call these the predictions).
  4. I convert these predictions to a tensor and set the requires_grad parameter to True:

torch.tensor(predictions, requires_grad=True)

  5. I then apply torch.nn.L1Loss to the predictions and my ground truth:

loss = torch.nn.L1Loss()(predictions, ground_truth)

When I run the code, I see the same loss for every epoch, and the model does not train.
For example:

Epoch : 1, Loss : 4.0
Epoch : 2, Loss : 4.0
Epoch : 3, Loss : 4.0
Epoch : 4, Loss : 4.0

I know that argmax() is not differentiable and hence should not be used in the loss. However, I do some post-processing on my outputs (which involves the argmax() operation) before converting the result to a tensor and applying the CrossEntropy loss to it.
If this is still wrong, is there an alternative to using argmax()?
Could you please help me with this?
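
The steps listed above can be sketched as follows, assuming a hypothetical one-layer `encoder` and made-up inputs. The sketch suggests why the loss stays constant: the new tensor created in step 4 is a fresh leaf, so the gradient stops there and never reaches the encoder's parameters.

```python
import numpy as np
import torch

# Hypothetical stand-ins for the encoder, its input, and the ground truth.
encoder = torch.nn.Linear(4, 6)
x = torch.randn(3, 4)
ground_truth = torch.tensor([1.0, 4.0, 3.0])

out = encoder(x)                          # tracked by autograd
arr = out.detach().numpy()                # step 2: leaves the autograd graph
predictions = np.argmax(arr, axis=1)      # step 3: integer indices
predictions = torch.tensor(               # step 4: a brand-new leaf tensor
    predictions, dtype=torch.float32, requires_grad=True
)

loss = torch.nn.L1Loss()(predictions, ground_truth)
loss.backward()                           # runs without raising an error

print(predictions.grad)                   # gradient stops at the new leaf
print(encoder.weight.grad)                # None: the encoder gets no gradient
```

Because the encoder's parameters receive no gradient, an optimizer step leaves them unchanged, which is consistent with an identical loss in every epoch.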

Hi Sourabh!

Please start a new thread and post your question there. It’s preferred
to resurrecting zombie threads.

Good luck!

K. Frank