Gradient of `maximum` and `minimum` functions


This is regarding the behavior of torch.maximum and torch.minimum functions.
Here is an example:
Let a be a scalar.
Currently, when computing torch.maximum(x, a), if x > a the gradient with respect to x is 1, and if x < a it is 0. BUT if x == a, the gradient is 0.5.
The same is true for torch.minimum.
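For concreteness, this behavior can be reproduced with a small script (a minimal sketch; the value 2.0 is an arbitrary choice for the scalar):

```python
import torch

a = torch.tensor(2.0)  # the "scalar" operand

# Check the gradient of torch.maximum(x, a) w.r.t. x in the three cases.
for val in (3.0, 1.0, 2.0):  # x > a, x < a, x == a
    x = torch.tensor(val, requires_grad=True)
    torch.maximum(x, a).backward()
    print(f"x = {val}: grad = {x.grad.item()}")
```

Running this prints a gradient of 1.0 when x > a, 0.0 when x < a, and 0.5 when x == a.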
Is there a mathematical reason for the 0.5 gradient when x = a, or is it done for numerical stability?


I think this is expected behavior as described here and implemented here.
CC @albanD to correct me

Ok, got it. So the reason for implementing them this way is for practical reasons rather than mathematical logic.

The mathematical logic is that, at this point, the function is not differentiable, but you can define subgradients. In this case, the subdifferential is the convex hull of {(1, 0), (0, 1)}. And to ensure we get a descent direction, we take the element of this set with minimum norm: (0.5, 0.5).
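The minimum-norm claim is easy to verify numerically (a hypothetical check, not part of PyTorch): the convex hull of (1, 0) and (0, 1) is the set of points (t, 1 − t) for t in [0, 1], and scanning that segment shows the norm is smallest at the midpoint.

```python
# Convex combinations t*(1, 0) + (1 - t)*(0, 1) = (t, 1 - t), t in [0, 1].
# Scan a fine grid of t and keep the element of minimum Euclidean norm.
best = min(
    ((t / 1000, 1 - t / 1000) for t in range(1001)),
    key=lambda g: g[0] ** 2 + g[1] ** 2,
)
print(best)  # (0.5, 0.5)
```

This is the (0.5, 0.5) element the answer refers to: it assigns gradient 0.5 to each input at the tie.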

Does that help?


Yes, makes sense, thanks 🙂