# Interesting argmax differential

I have two code blocks with argmax in both.
First,

```python
import torch
import torch.nn as nn

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
y = torch.argmax(x, dim=-1)
y = torch.unsqueeze(y, 0)
y = y.to(torch.float32)
y = nn.Linear(1, 1)(y)
y.backward()
```

And second,

```python
import torch

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
y = torch.argmax(x, dim=-1)
y.backward()
```

The interesting thing is that the first code block runs successfully, but the second fails. Why?

Hi Trick!

As a side note, right after the `argmax()` call `y` is `long`, so it can't
have `requires_grad = True`. So you wouldn't be able to run
`.backward()` on it at that point.

In your first code block, you make `y` a `float` (but it still doesn't
have `requires_grad = True`). Then, when you run it through `Linear`,
it gets multiplied by `Linear`'s `weight`, which does have
`requires_grad = True` (and gets `bias` added to it, which also has
`requires_grad = True`). So the result has `requires_grad = True`,
and you can run `.backward()` on it.

In your second code block, the fact that `x` has `requires_grad = True` is
irrelevant. The result of `argmax()` is `long` (which can't have
`requires_grad = True`), so you can't run `.backward()` on it.
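As a quick check, here is a minimal sketch of that point (note that PyTorch's `long` dtype is reported as `torch.int64`):

```python
import torch

x = torch.tensor([0.1, 0.2], requires_grad=True)
y = torch.argmax(x, dim=-1)

print(y.dtype)          # torch.int64, i.e. long
print(y.requires_grad)  # False -- the graph is broken here

# Trying to force it fails: only floating-point (and complex)
# tensors can require gradients.
try:
    y.requires_grad_(True)
except RuntimeError as e:
    print("RuntimeError:", e)
```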

Best.

K. Frank


Hi KFrank,

Thanks for your reply. I am still confused. In the first code block, when `y.backward()` is called, I would think the backward operation runs through the first two lines, i.e. that `dy/dx` gets calculated; but `argmax` is not differentiable, so it should cause an exception. Your explanation suggests this is wrong.

According to your explanation, the backward operation does not run through the first two lines, because `y` does not have `requires_grad`, right? So the question is: in PyTorch, is `dy/dx` set to zero, like the derivative of `ReLU` at zero, or is the backward operation simply stopped earlier, for instance at the `nn.Linear` line?

Hi Trick!

This does not happen because, as you note, `argmax()` is not
differentiable: it "breaks the computation graph," and the `.backward()`
processing does not flow back to the first two lines, so no attempt is
made to calculate `dy/dx`.

As an aside, if there were no path backward through the computation
graph, `.backward()` would raise an exception. However (see below),
there is another path backward to a leaf with `requires_grad = True`,
so, in your example, no exception is raised.

Correct: the backward operation does not flow back through the first
two lines, because `y` does not have `requires_grad`.

The part of the backward operation that is trying to flow back to `x` is
stopped (although I would say that it is stopped at the
`y = y.to(torch.float32)` line). So `dy/dx` is not set to zero, but,
rather, is never created. More precisely, `x.grad` is never created.
And, just to say it explicitly, this is different from creating `x.grad`
and having it be zero.
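To make the `None`-versus-zero distinction concrete, here is a small sketch (the second tensor, `z`, is my own illustration, not from your code):

```python
import torch
import torch.nn as nn

# x's backward path is cut by argmax, so x.grad is never created.
x = torch.tensor([0.1, 0.2], requires_grad=True)
y = torch.argmax(x, dim=-1).unsqueeze(0).to(torch.float32)
nn.Linear(1, 1)(y).backward()
print(x.grad)  # None -- never created

# z's backward path exists, but the gradient happens to be zero.
z = torch.tensor([0.1, 0.2], requires_grad=True)
(z * 0.0).sum().backward()
print(z.grad)  # tensor([0., 0.]) -- created, and zero
```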

Now, as for there being another backward path, please note that
`Linear` contains tensors (`weight` and `bias`) that have
`requires_grad = True`. So, even though `.backward()` does not
create and calculate a `.grad` for `x`, it does for the tensors in `Linear`.

(And because there are some `.grad`s that are being calculated when
you run `.backward()`, no exception is raised.)
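For contrast, if there were no parameter with `requires_grad = True` anywhere in the graph -- e.g. keeping the `float` cast but dropping the `Linear` layer -- `.backward()` would raise (a sketch; the exact error message may vary across versions):

```python
import torch

x = torch.tensor([0.1, 0.2], requires_grad=True)
y = torch.argmax(x, dim=-1).to(torch.float32)  # graph broken, no other path

try:
    y.backward()
except RuntimeError as e:
    print("RuntimeError:", e)
```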

Here is a short example that illustrates this:

```python
>>> import torch
>>> torch.__version__
'1.7.1'
>>> x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
>>> y = torch.argmax(x, dim=-1)
>>> y = torch.unsqueeze(y, 0)
>>> y = y.to(torch.float32)
>>> my_linear = torch.nn.Linear(1, 1)
>>> y = my_linear(y)
>>> x.grad is None
True
>>> my_linear.weight.grad is None
True
>>> y.backward()
>>> x.grad is None
True
>>> my_linear.weight.grad is None
False
```

Notice that `x.grad` never gets created, and that `my_linear.weight.grad`
only gets created (and calculated) when `y.backward()` is called.