Torch.argmax() causes loss.backward() to fail

I am training a network with PyTorch. During training, I use 64 bins to predict each coordinate dimension.
I can't get it to work: my loss.backward() fails!

    output_coor_x_ = output_coor_x_.squeeze()
    output_coor_y_ = output_coor_y_.squeeze()
    output_coor_z_ = output_coor_z_.squeeze()

    ####

    # pick the most likely bin for each coordinate, then stack x/y/z
    output_coor_ = torch.stack([torch.argmax(output_coor_x_, dim=0),
                                torch.argmax(output_coor_y_, dim=0),
                                torch.argmax(output_coor_z_, dim=0)], dim=2)
    output_coor_[output_coor_ == cfg.network.coor_bin] = 0
    output_coor_ = 2.0 * output_coor_.float() / 63.0 - 1.0    # map bin indices to [-1, 1]
    output_coor_[:, :, 0] = output_coor_[:, :, 0] * abs(model_min_x)

If I compute the loss from the variables before the '####' line, loss.backward() works; but when I use the variables after '####', loss.backward() fails with the following output:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I am sure that the model's parameters have requires_grad=True.

Maybe I shouldn’t use torch.argmax()?


Hi Kiruto!

In short, yes, don’t use argmax() in your loss function.

loss.backward() computes the gradient (derivative) of loss
with respect to your model parameters. (The optimizer then uses
a gradient-descent algorithm to adjust those parameters to make
the loss smaller.)

For this to work, your loss function has to be differentiable.
argmax() isn’t differentiable, so you can’t use it as part of
your loss.
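
As a quick illustration (using a small made-up tensor, not your actual shapes), argmax() returns integer indices with requires_grad=False and no grad_fn, so the autograd graph is cut at that point:

    import torch

    logits = torch.randn(64, 8, 8, requires_grad=True)   # e.g. 64 bins per spatial location

    idx = torch.argmax(logits, dim=0)        # integer bin indices, shape (8, 8)
    print(idx.dtype)                         # torch.int64 -- not a floating-point tensor
    print(idx.requires_grad, idx.grad_fn)    # False None -- the graph is cut here

    # Any loss built from idx has nothing to backpropagate through:
    # idx.float().mean().backward() raises
    # "element 0 of tensors does not require grad and does not have a grad_fn"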

I don’t follow what you are doing in your code, so I don’t know
specifically how you might fix it. But as a general comment,
you might investigate whether a differentiable approximation
to argmax() would still make sense in your loss function.

softmax() (which should more accurately be called softargmax())
can be understood to be a differentiable approximation to argmax()
(much in the same way that sigmoid() is a differentiable
approximation to a step function).

If you can reformulate your loss function to use softmax(), you
can make it differentiable and loss.backward() should work.
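
For example, here is one way you might reformulate it (just a sketch with made-up names and shapes -- soft_argmax, the 64-bin size, and the 8x8 grid are my assumptions, not your code): replace the hard argmax over bins with a softmax-weighted expectation of the bin indices, which stays differentiable:

    import torch
    import torch.nn.functional as F

    def soft_argmax(bin_logits, dim=0):
        # differentiable stand-in for argmax over `dim`: returns the
        # expected bin index under the softmax distribution
        num_bins = bin_logits.size(dim)
        probs = F.softmax(bin_logits, dim=dim)
        bins = torch.arange(num_bins, dtype=probs.dtype, device=probs.device)
        shape = [1] * bin_logits.dim()
        shape[dim] = num_bins
        return (probs * bins.view(shape)).sum(dim=dim)   # float tensor, keeps grad_fn

    # usage with made-up shapes: 64 bins over an 8x8 grid
    output_coor_x_ = torch.randn(64, 8, 8, requires_grad=True)
    coor_x = soft_argmax(output_coor_x_, dim=0)   # (8, 8), differentiable
    coor_x = 2.0 * coor_x / 63.0 - 1.0            # map expected bin index to [-1, 1]
    loss = coor_x.mean()
    loss.backward()                               # works: gradients reach the logits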

Good luck.

K. Frank
