I have two code blocks with argmax in both.
First,

```python
import torch
import torch.nn as nn
x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
y = torch.argmax(x, dim=-1)
y = torch.unsqueeze(y, 0)
y = y.to(torch.float32)
y = nn.Linear(1, 1)(y)
y.backward()
```

And second,

```python
import torch
x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
y = torch.argmax(x, dim=-1)
y.backward()  # fails with a RuntimeError
```

Interestingly, the first block runs successfully, but the second one fails. Why?

As a side note, right after the argmax() call in your first example, y is a long tensor, so it can't have requires_grad = True, and you wouldn't be able to run .backward() on it at that point.

Then, with y.to(torch.float32), you make y a float (but it still doesn't have requires_grad = True). When you run it through Linear,
it gets multiplied by Linear's weight, which does have requires_grad = True, and then Linear's bias, which also has requires_grad = True, is added to it. So the result has requires_grad = True,
and you can run .backward() on it.
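To make the step-by-step reasoning above concrete, here is a minimal sketch (assuming a recent PyTorch release) that reports the dtype and requires_grad of y after each stage of your first example. The names y_idx, y_float, linear, and out are only illustrative renamings of your y.

```python
import torch
import torch.nn as nn

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)

y_idx = torch.argmax(x, dim=-1)                         # integer index, no autograd history
y_float = torch.unsqueeze(y_idx, 0).to(torch.float32)   # now float, but still no history
linear = nn.Linear(1, 1)                                # weight and bias require grad
out = linear(y_float)                                   # result depends on weight and bias

for name, t in [("argmax", y_idx), ("float cast", y_float), ("Linear output", out)]:
    print(f"{name:14} dtype={t.dtype} requires_grad={t.requires_grad}")

# Works: out has a graph back to linear.weight and linear.bias.
out.backward()
```

Only the Linear output reports requires_grad=True; the float cast alone does not attach y to any graph.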

In your second example, the fact that x has requires_grad = True is irrelevant. The
result of argmax() is long (which can't have requires_grad = True),
so you can't run .backward() on it.
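A minimal sketch of your second example that catches the errors instead of letting them propagate (assuming a recent PyTorch release; the exact error messages vary between versions). It also shows that an integer tensor refuses requires_grad_(True) outright.

```python
import torch

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
y = torch.argmax(x, dim=-1)
print(y.dtype, y.requires_grad)  # torch.int64 False

# backward() has nothing to differentiate: y has no grad_fn.
try:
    y.backward()
    backward_error = None
except RuntimeError as e:
    backward_error = e
    print("backward failed:", e)

# Integer tensors cannot be made to require gradients at all.
try:
    y.requires_grad_(True)
    requires_grad_error = None
except RuntimeError as e:
    requires_grad_error = e
    print("requires_grad_ failed:", e)
```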

Hi KFrank,
Thanks for your reply, but I am still confused. In the first code, when y.backward() is called, I expected the backward operation to run through the first two lines, i.e., dy/dx would be calculated. But argmax is not differentiable, so that should cause an exception. Your explanation suggests this is wrong.
According to your explanation, the backward operation does not run through the first two lines, because y does not have requires_grad, right? So the question is: in PyTorch, is dy/dx set to zero (like the derivative of ReLU at zero), or does the backward operation simply stop earlier, for instance at line 5?

Right: y does not have requires_grad = True (and, being long, can't have it), so it “breaks the computation graph”, and the .backward() processing
does not flow back to the first two lines, so no attempt to calculate dy/dx
is made.
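One way to see the graph being broken is to look at grad_fn, the backward node autograd attaches to a result. A small sketch, assuming a recent PyTorch release; the comparison with torch.max(x, dim=-1), which returns both the (differentiable) values and the indices, is added here only for contrast.

```python
import torch

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)

idx = torch.argmax(x, dim=-1)
print(idx.grad_fn)  # None: no backward node links idx to x

# For contrast: torch.max along a dim returns both values and indices.
values, indices = torch.max(x, dim=-1)
print(values.grad_fn)   # a backward node: values stay connected to x
print(indices.grad_fn)  # None: indices, like argmax, are cut off from x
```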

As an aside, if there were no path backward through the computation
graph, .backward() would raise an exception. However (see below),
there is another path backward to a leaf with requires_grad = True,
so, in your example, no exception is raised.
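To illustrate the no-path case, here is a sketch of a hypothetical variation of your first example (assuming a recent PyTorch release) that freezes Linear's weight and bias, so no leaf with requires_grad = True is reachable and .backward() raises.

```python
import torch
import torch.nn as nn

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
linear = nn.Linear(1, 1)
linear.requires_grad_(False)  # freeze weight and bias: removes the only other path

y = torch.argmax(x, dim=-1).unsqueeze(0).to(torch.float32)
out = linear(y)
print(out.requires_grad)  # False

try:
    out.backward()
    error = None
except RuntimeError as e:
    error = e
    print("backward failed:", e)
```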

Correct.

The part of the backward operation that is trying to flow back to x is
stopped (although I would say that it is stopped at line 4). So dy/dx
is not set to zero; rather, it is never created. More precisely, x.grad
is never created (it remains None). And, just to say it explicitly, this is different from
creating x.grad and having it be zero.
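The difference between x.grad never being created and x.grad being zero can be checked directly. In this sketch (assuming a recent PyTorch release), x2 and the multiply-by-zero graph are hypothetical additions used only for contrast.

```python
import torch
import torch.nn as nn

# Path through argmax (your first example): x.grad is never created.
x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
y = torch.argmax(x, dim=-1).unsqueeze(0).to(torch.float32)
nn.Linear(1, 1)(y).backward()
print(x.grad)  # None

# A graph that does reach x2, but with zero derivative: x2.grad exists and is zero.
x2 = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
(x2 * 0).sum().backward()
print(x2.grad)  # tensor([0., 0.])
```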

Now, as for there being another backward path, please note that Linear contains tensors (weight and bias) that have requires_grad = True. So, even though .backward() does not
create and calculate a .grad for x, it does for the tensors in Linear.

(And because there are some .grads that are being calculated when
you run .backward(), no exception is raised.)
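A sketch of the gradients that .backward() does compute in your first example (assuming a recent PyTorch release; the name linear is just an illustrative handle on the layer). Since argmax([0.1, 0.2]) is 1, Linear computes out = w * 1.0 + b, so both gradients come out as exactly 1.0, whatever the random initialization.

```python
import torch
import torch.nn as nn

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
linear = nn.Linear(1, 1)

y = torch.argmax(x, dim=-1).unsqueeze(0).to(torch.float32)  # tensor([1.])
out = linear(y)  # out = w * y + b
out.backward()

# d(out)/dw = y = 1.0 and d(out)/db = 1.0
print(linear.weight.grad)  # tensor([[1.]])
print(linear.bias.grad)    # tensor([1.])
```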