The example below should in theory compute the same thing, but gives different gradients.
If multiple elements in the input have the minimal value, should gradients flow back through all of them or through a single one?
If gradients should only flow through one of the min elements, how do we determine which? (EDIT: The docs here say that the first element will be selected: torch.min — PyTorch 2.6 documentation)
In this example, the gradient is split between both min elements in an all-reduce, but only one (non-deterministically?) is selected when reducing over a dim.
import torch

# Reducing over all elements: the gradient is split between the tied minima.
a = torch.tensor([0.1, 0.3, 0.1], dtype=torch.float32, requires_grad=True)
b = a.min()
b.backward()
print(a.grad)  # Output is tensor([0.5000, 0.0000, 0.5000])

# Reducing over a dim: the entire gradient goes to a single min element.
a_cp = torch.tensor([0.1, 0.3, 0.1], dtype=torch.float32, requires_grad=True)
c, d = a_cp.min(dim=0)
c.backward()
print(a_cp.grad)  # Output is tensor([1., 0., 0.])
I see the same behavior as you on pytorch version 2.6.0. I do see this as a (minor) bug.
I read the github issue you linked to as follows:
This is a legitimate issue.
The torch.min (dim = 0) behavior is “correct.”
The issue was closed as “completed,” which presumably means fixed.
I have no idea whether it was actually fixed, but if it was, the current behavior is a regression. (Note, I haven’t tried to verify whether the current, “correct” min (dim = 0) behavior is deterministic or not.)
The torch.min documentation states: “This function produces deterministic (sub)gradients unlike min(dim=0)”
This agrees neither with the actual behavior – subgradients are not produced for min() – nor with the desired behavior as indicated by the github issue – subgradients are produced for min (dim = 0).
I agree with you that this is a real, if minor, bug and ideally it should be fixed. At a minimum, the documentation should be brought into agreement with reality.
Thanks for your detailed reply! I haven’t actually come across a use case where this behavior has caused me problems. I’m writing my own “torch lite” library and only noticed this when writing my own implementation, so it was confusing what the intended behaviour should be.
As mentioned above, there is a github issue tracking this behaviour for those curious to follow!
There is no regression and the issue was fixed correctly. For non-differentiable functions we return the lowest-norm subgradient, which, in the case of min with no dimension, means sending part of the gradient to all the min elements. min with a dimension returns both an index of (one of the) minimum elements and its value, so, due to the semantics of this function, the gradient has to go to the index returned by the operation. If you want the same behavior as min with no dimension, you should use amin(dim=0), which propagates the gradient to all the minimum value elements.
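For example, a quick sketch reusing the tensor from the example above (the printed value is what I’d expect from the documented even split, worth verifying on your version):

import torch

# amin(dim=0) returns no index, so the gradient is split across the tied
# minima, matching the behavior of min() with no dimension.
a = torch.tensor([0.1, 0.3, 0.1], requires_grad=True)
a.amin(dim=0).backward()
print(a.grad)  # expected: tensor([0.5000, 0.0000, 0.5000])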
This is a bit out of my depth of understanding, but are you able to shed some light on why gradients would propagate to single elements vs split to all matching? Is the split gradient the proper default when working with these reduction operations, but because semantically .min also ties along with a specific index, that restricts where the gradient should flow back to?
Correct, the split gradient is the minimum norm subgradient, which is what we generally want, but for the min operation with an index we should propagate the gradient only to this single index.
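To make the “minimum norm” part concrete (my own illustration, not from the thread): with a two-way tie, any convex combination of the two one-hot gradients is a valid subgradient, and the even split has the smallest norm:

import torch

# Subgradients of min([0.1, 0.3, 0.1]) are lam*[1,0,0] + (1-lam)*[0,0,1]
# for lam in [0, 1]; the norm is smallest at the even split lam = 0.5.
for lam in [0.0, 0.25, 0.5, 0.75, 1.0]:
    g = torch.tensor([lam, 0.0, 1.0 - lam])
    print(lam, g.norm().item())
# lam = 0.5 gives norm sqrt(0.5) ~ 0.707, i.e. the split gradient [0.5, 0, 0.5].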
This question is not to imply that a correct fix hasn’t been put into place.
Am I missing some other (more relevant) github issue? My reading of the discussion in github issue 35699 (which Jake cited) is pretty clearly that the min (dim) behavior (namely, that only one element gets the gradient) is correct and that min() would be made consistent with min (dim). I haven’t looked at the PRs, but the issue is then “closed as completed,” so pytorch’s current behavior is out of sync with that particular github discussion.
As for the documentation, I find it to be so terse as to be ambiguous. I suppose it can be read as being consistent with pytorch’s current behavior, but that’s not very helpful.
I believe I understand the logic – there are two competing concerns. Because min (dim) returns an index, you want to populate just the gradient for that index. On the other hand, because min() does not return an index, you prefer to spread the gradient across multiple elements that have the minimum value in order to get the norm-wise smallest subgradient (but at the cost of being inconsistent with min (dim)).
To me, consistency seems to be the more important consideration. Can you give a real-world use case where I would want, or be surprised by not getting, the norm-wise minimum subgradient? (As it stands, I am surprised by not getting the behavior of min (dim).)
In #35699, prior to the fix, the gradient at all minimum positions was 1, so it was not a subgradient at all (valid subgradients here are convex combinations of the tied one-hot vectors, so their entries must sum to 1) and that was clearly wrong. It was fixed to return the minimum norm subgradient.
As for the consistency concern, the inconsistency is in the naming of two conceptually different functions – one returning indices and the other not – as min. Unfortunately, we cannot change this due to backward-compatibility concerns, and it would require a lot of churn in users’ code for no practical gain. The real counterpart of min() with no dimension is amin(dim=...), and the gradient behavior of these functions is consistent.
For a discussion on the choice of minimum norm subgradient see Wrong gradient for torch.norm(x, p=float('inf')) when input tensor has non-unique max values · Issue #41779 · pytorch/pytorch · GitHub
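For the curious, here is a minimal sketch of the situation that issue covers (the comment shows what the minimum-norm convention would predict; I haven’t verified the actual output on every version):

import torch

# |x| attains its maximum at two positions, so the inf-norm has a tie.
x = torch.tensor([2.0, -2.0, 1.0], requires_grad=True)
torch.norm(x, p=float('inf')).backward()
print(x.grad)  # minimum-norm subgradient would be tensor([0.5000, -0.5000, 0.0000])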
One last question. I noticed that in max_pool2d, the gradient follows max(dim=0), where the entire gradient flows back to a single matching element. From the logic above, I would have assumed that the gradient should be split between all elements. From reading your linked discussion, it seems like the decision was made to follow the behaviour of the pooling ops, but why did the pooling ops decide on not splitting the gradient?
Consider the following:
import torch

# All four inputs tie for the window maximum, yet the entire gradient
# flows back to a single element rather than being split.
x = torch.ones((1, 1, 2, 2), requires_grad=True)
out = torch.nn.functional.max_pool2d(x, kernel_size=2, stride=1, padding=0)
grad_out = torch.ones_like(out)
out.backward(grad_out)
print(x.grad)  # e.g. tensor([[[[1., 0.], [0., 0.]]]])
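For contrast (my own check, not from the original post), amax over the same tied values does split the gradient evenly, as the minimum-norm convention suggests:

import torch

y = torch.ones(4, requires_grad=True)
y.amax(dim=0).backward()
print(y.grad)  # tensor([0.2500, 0.2500, 0.2500, 0.2500])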
Thank you for this explanation – this makes a lot more sense now.
There was a real bug (real in the sense of being mathematically wrong, above and beyond any consistency concerns). This bug was fixed and the github issue was closed.
That the bug was fixed with the minimum-norm-subgradient approach rather than the consistency-with-min (dim) approach is a separate issue, and the minimum-norm approach is certainly mathematically legitimate.
Yeah max_pool behavior deviates from minimum norm subgradient, so it’s a bug, but tbh I don’t think there’s a lot of appetite for fixing it. It returns a subgradient, just not the minimum norm subgradient.
Thanks for confirming. I agree that it’s not a high priority issue, but I wanted to make sure I understand what the desirable behaviour should be with respect to subgradients.