Given a tensor of values in the range [0, 1], multiplying these values by a scalar `p` and applying a softmax gives probabilities that sum to 1. Increasing `p` pushes each of these probabilities toward either 0 or 1. See this example:

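```python
import matplotlib.pyplot as plt
import torch

values = torch.tensor([0.00, 0.25, 0.50, 0.75, 1.00])
ps = list(range(20))

# One row per p: the softmax of the scaled values
softmax_values = torch.stack([torch.softmax(p * values, dim=0) for p in ps])

# One color per input value: the dot marks the input, the line its softmax over p
for value, outputs in zip(values, softmax_values.T):
    (line,) = plt.plot(ps, outputs.numpy())
    plt.plot([0], [value.item()], "o", c=line.get_color())

plt.xlabel("p")
plt.show()
```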

Dots mark the input values; each line shows the softmax value of the corresponding input as a function of `p`.

As `p` decreases toward 0, all softmax values converge to 1 / n_values = 1 / 5 = 0.2.
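A quick numerical check of both limits, as a minimal sketch using only `torch.softmax`: at `p = 0` the output is exactly uniform, and for a large `p` (100 here, an arbitrary choice) it is numerically indistinguishable from a one-hot vector at the largest input.

```python
import torch

values = torch.tensor([0.00, 0.25, 0.50, 0.75, 1.00])

# p = 0: all logits are 0, so the softmax is uniform (1 / 5 = 0.2 everywhere)
uniform = torch.softmax(0 * values, dim=0)

# Large p: the largest input dominates, so the softmax approaches one-hot
near_one_hot = torch.softmax(100 * values, dim=0)
one_hot = torch.zeros_like(values)
one_hot[values.argmax()] = 1.0

print(uniform)
print(torch.allclose(near_one_hot, one_hot, atol=1e-6))  # True
```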

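How do I instead create a function that starts from the input values themselves (rather than from 1 / n_values) and gradually transforms them toward their softmax limit (1 for the largest value, 0 for the others) as p → ∞?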

That function would look like: [figure not available]

Hi Zimo!

There are lots of possibilities. (You could choose between them by
imposing further conditions on the behavior of your function.)
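To make "lots of possibilities" concrete, here is one other candidate, a hypothetical blend that is *not* the function K. Frank shows next (the name `blended` is made up for illustration): weight the raw values by `1 / p` and `softmax(p * values)` by `1 - 1 / p`. At `p = 1` it returns the input exactly, as `p` grows the softmax term (which tends to a one-hot vector) takes over, and it is differentiable everywhere.

```python
import torch

def blended(values: torch.Tensor, p: float) -> torch.Tensor:
    # Hypothetical alternative, not K. Frank's function.
    # Weight 1/p on the raw values and 1 - 1/p on softmax(p * values):
    #   p = 1    -> the output equals the input
    #   p -> inf -> the softmax term dominates and tends to one-hot
    w = 1.0 / p
    return w * values + (1.0 - w) * torch.softmax(p * values, dim=0)

values = torch.tensor([0.1, 0.5, 0.9])
print(blended(values, 1.0))     # same as values
print(blended(values, 1000.0))  # close to [0, 0, 1]
```

Compared with the `torch.where()` version that follows, near-ties between the largest values are shared out by the softmax rather than decided by an exact `==` comparison, and every element influences every output.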

Here is an illustration of one possible choice:

```python
>>> import torch
>>> torch.__version__
'1.7.1'
>>> _ = torch.manual_seed (2021)
>>> def gradual (values, p):
...     return torch.where (values == values.max(), 1.0 - (1.0 - values)**p, values**p)
...
>>> values = torch.rand (5)
>>> values
tensor([0.1304, 0.5134, 0.7426, 0.7159, 0.5705])
>>> gradual (values, 2)
tensor([0.0170, 0.2636, 0.9337, 0.5125, 0.3254])
>>> gradual (values, 3)
tensor([0.0022, 0.1353, 0.9829, 0.3669, 0.1857])
>>> gradual (values, 5)
tensor([3.7761e-05, 3.5664e-02, 9.9887e-01, 1.8803e-01, 6.0419e-02])
>>> gradual (values, 10)
tensor([1.4259e-09, 1.2719e-03, 1.0000e+00, 3.5355e-02, 3.6504e-03])
```
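A self-contained version of the same function with its two endpoint properties checked explicitly. This is a sketch that reuses the rounded values printed above rather than re-running the seeded `torch.rand` call; the name `gradual` matches the one used later in the thread. At `p = 1` it is the identity (up to floating-point rounding), and for large `p` it approaches a one-hot vector at the largest element.

```python
import torch

def gradual(values: torch.Tensor, p: float) -> torch.Tensor:
    # Largest element: 1 - (1 - v)^p rises toward 1 as p grows.
    # Every other element: v^p falls toward 0 (those values are < 1).
    # At p = 1 both branches reduce to v, so the curves start at the inputs.
    return torch.where(values == values.max(), 1.0 - (1.0 - values) ** p, values ** p)

# Rounded copy of the torch.rand(5) output printed above
values = torch.tensor([0.1304, 0.5134, 0.7426, 0.7159, 0.5705])

start = gradual(values, 1)  # identity, up to floating-point rounding
end = gradual(values, 50)   # essentially one-hot at index 2 (0.7426)
print(start)
print(end)
```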

Best.

K. Frank

Thanks for the answer, K. Frank.

I will try to tinker with this solution in a few hours.

Do you happen to know whether `torch.where` breaks the computation graph? I am looking for a differentiable function, even though I know that for large values of `p` the gradient will vanish.

Hi Zimo!

`torch.where()` does not break the computation graph: autograd passes gradients back through whichever branch it selects for each element.

But, furthermore, in this specific case, each individual element of `values`
only participates in a single branch of the `torch.where()` function.
I’m just using it as a more compact way to single out the largest element
in `values`, rather than indexing with `argmax()`.
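To see that gradients reach `values` through `torch.where()`, one can compare autograd's result with the derivative of whichever branch each element takes. A minimal sketch with `p = 3` (an arbitrary choice):

```python
import torch

def gradual(values, p):
    return torch.where(values == values.max(), 1.0 - (1.0 - values) ** p, values ** p)

p = 3
values = torch.tensor([0.2, 0.4, 0.6, 0.8], requires_grad=True)
gradual(values, p).sum().backward()

# Analytic derivatives of the branch each element actually takes:
#   largest element: d/dv [1 - (1 - v)^p] = p * (1 - v)^(p - 1)
#   all the others:  d/dv [v^p]           = p * v^(p - 1)
v = values.detach()
expected = torch.where(v == v.max(), p * (1 - v) ** (p - 1), p * v ** (p - 1))
print(torch.allclose(values.grad, expected))  # True
```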

Best.

K. Frank

Yes, it seems to work quite nicely:

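```python
import matplotlib.pyplot as plt
import torch

def gradual(values, p):
    # gradual() as defined by K. Frank above
    return torch.where(values == values.max(), 1.0 - (1.0 - values) ** p, values ** p)

values = torch.tensor([0.20, 0.40, 0.60, 0.80], requires_grad=True)
ps = list(range(1, 20))

softmax_values = torch.stack([gradual(values, p) for p in ps])

# Detach before plotting: tensors that require grad cannot be converted to NumPy.
# At p = 1, gradual() returns the inputs, so each dot sits at the start of its curve.
for value, outputs in zip(values.detach(), softmax_values.detach().T):
    (line,) = plt.plot(ps, outputs.numpy())
    plt.plot([ps[0]], [value.item()], "o", c=line.get_color())

plt.xlabel("p")
plt.show()

loss = softmax_values.mean()
loss.backward()
print(values.grad)  # tensor([0.0206, 0.0365, 0.0822, 0.0206])
```
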
Do you know whether that is generally the case, i.e. that using `torch.where(x == max(x), t1, t2)` is differentiable while `torch.argmax(x)` is not? I guess they have different use cases, but this feels like a very nice “trick” to keep the gradient flowing.
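A sketch of why the two behave differently, based on standard autograd behavior: `torch.argmax()` returns an integer index, and integer tensors cannot carry gradients. `torch.where()` also relies on a non-differentiable boolean mask, but only to *select* float values that are themselves functions of `x`, so gradients flow back through the selected entries. Using the `argmax()` index to pick a value out of `x` is differentiable in the same way, which fits K. Frank's remark that `torch.where()` is just a compact alternative to indexing with `argmax()`.

```python
import torch

x = torch.tensor([0.2, 0.4, 0.6, 0.8], requires_grad=True)

# argmax yields an int64 index, which is cut off from the autograd graph
idx = torch.argmax(x)
print(idx.dtype, idx.requires_grad)  # torch.int64 False

# ...but using that index to pick a value from x is differentiable w.r.t. x
picked = x[idx]
print(picked.requires_grad)  # True

# torch.where: the boolean mask carries no gradient, the selected floats do
mask = x == x.max()
y = torch.where(mask, 1.0 - (1.0 - x) ** 2, x ** 2)
print(mask.requires_grad, y.requires_grad)  # False True

y.sum().backward()
# Non-max entries get d(x^2)/dx = 2x; the max gets d(1 - (1 - x)^2)/dx = 2(1 - x)
print(x.grad)
```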