Interesting argmax differentiability

I have two code blocks, both using argmax.
First,

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
y = torch.argmax(x, dim=-1)
y = torch.unsqueeze(y, 0)
y = y.to(torch.float32)
y = nn.Linear(1, 1)(y)
y.backward()

And second,

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
y = torch.argmax(x, dim=-1)
y.backward()

The interesting thing is that the first code block runs successfully, but the second fails. Why?

Hi Trick!

As a side note: in your first code block, right after the argmax() call,
y is long, so it can’t have requires_grad = True. So you wouldn’t be
able to run .backward() on it at that point.

With y = y.to(torch.float32), you make y a float (but it still doesn’t
have requires_grad = True). Then when you run it through Linear,
it gets multiplied by Linear's weight, which does have
requires_grad = True (and gets the bias added to it, which also has
requires_grad = True). So the result has requires_grad = True,
and you can run .backward() on it.

In your second code block, the fact that x has requires_grad = True is
irrelevant. The result of argmax() is long (which can’t have
requires_grad = True), so you can’t run .backward() on it.
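You can check all three pieces of this directly. A minimal sketch (the exact wording of the error messages varies by PyTorch version):

```python
import torch

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
y = torch.argmax(x, dim=-1)

# argmax returns an integer (long) tensor with no grad history
print(y.dtype)          # torch.int64
print(y.requires_grad)  # False

# Integer tensors can't require grad ...
try:
    y.requires_grad_(True)
except RuntimeError as e:
    print("requires_grad_ failed:", e)

# ... so y.backward() has no graph to flow through and raises
try:
    y.backward()
except RuntimeError as e:
    print("backward failed:", e)
```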

Best.

K. Frank

Hi, K. Frank,
Thanks for your reply, but I am still confused. In the first code block, when y.backward() is called, I thought the backward operation would run through the first two lines, i.e., dy/dx would be calculated; but argmax is not differentiable, so it should cause an exception. Your explanation suggests this is wrong.
According to your explanation, the backward operation does not run through the first two lines, because y does not have requires_grad, right? So the question is: in PyTorch, is dy/dx set to ZERO, like the derivative of ReLU at point zero, or is the backward operation simply stopped earlier, for instance at line 5?

Hi Trick!

The exception does not happen because, as you note, the output of
argmax() does not have requires_grad, so it “breaks the computation
graph” and the .backward() processing does not flow back to the first
two lines; no attempt to calculate dy/dx is made.

As an aside, if there were no path backward through the computation
graph, .backward() would raise an exception. However (see below),
there is another path backwards to a leaf with requires_grad = True,
so, in your example, no exception is raised.

Correct, the backward operation does not run through the first two lines.

The part of the backward operation that is trying to flow back to x is
stopped (although I would say that it is stopped at line 4). So dy/dx
is not set to zero, but, rather, is never created. More precisely, x.grad
is never created. And, just to say it explicitly, this is different than
creating x.grad and having it be zero.
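The difference between “never created” and “zero” is easy to see. A quick sketch (the zero-gradient construction via 0.0 * x is my addition, just for contrast):

```python
import torch

x = torch.tensor([0.1, 0.2], requires_grad=True)

# Never created: .grad starts out as None
print(x.grad)  # None

# Contrast with an actual zero gradient: here x contributes nothing
# to the result, so backward() writes zeros into x.grad
z = (0.0 * x).sum()
z.backward()
print(x.grad)  # tensor([0., 0.])
```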

Now, as for there being another backward path, please note that
Linear contains tensors (weight and bias) that have
requires_grad = True. So, even though .backward() does not
create and calculate a .grad for x, it does for the tensors in Linear.

(And because there are some .grads that are being calculated when
you run .backward(), no exception is raised.)

Here is a short example that illustrates this:

>>> import torch
>>> torch.__version__
'1.7.1'
>>> x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
>>> y = torch.argmax(x, dim=-1)
>>> y = torch.unsqueeze(y, 0)
>>> y = y.to(torch.float32)
>>> my_linear = torch.nn.Linear(1, 1)
>>> y = my_linear(y)
>>> x.requires_grad
True
>>> x.grad
>>> my_linear.weight.requires_grad
True
>>> my_linear.weight.grad
>>> y.backward()
>>> x.grad
>>> my_linear.weight.grad
tensor([[1.]])

Notice that x.grad never gets created, and that my_linear.weight.grad
only gets created (and calculated) when y.backward() is called.
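Conversely, to illustrate the aside above: if you also turn off requires_grad on Linear's parameters (my modification, not in the original code), there is no backward path left anywhere, and .backward() does raise:

```python
import torch
import torch.nn as nn

x = torch.tensor([0.1, 0.2], dtype=torch.float32, requires_grad=True)
y = torch.argmax(x, dim=-1)          # long; breaks the graph back to x
y = torch.unsqueeze(y, 0).to(torch.float32)

frozen = nn.Linear(1, 1)
for p in frozen.parameters():
    p.requires_grad_(False)          # remove the remaining backward path

out = frozen(y)
try:
    out.backward()
except RuntimeError as e:
    print("backward raised:", e)
```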

Best.

K. Frank

Thanks, K. Frank. Now I’ve got it!