Gradual softmax?

Given a tensor of values in the range [0, 1], multiplying these values by a scalar p and applying a softmax gives scaled probabilities that sum to 1. Increasing p pushes the largest value's softmax output toward 1 and the others toward 0. See the example:

import torch
import numpy as np
import matplotlib.pyplot as plt

values = torch.tensor([0.00, 0.25, 0.50, 0.75, 1.00])

# softmax of p * values for p = 0 .. 19
softmax_values = torch.stack([torch.softmax(p * values, 0) for p in np.arange(0, 20)])

for input, outputs in zip(values, softmax_values.T):
    c = plt.gca()._get_lines.get_next_color()
    plt.plot([0], [input], 'o', c=c)  # dot: the input value
    plt.plot(outputs, c=c)            # line: its softmax output as p grows

plt.xlabel("p")
plt.show()

[plot: dots at the input values with their softmax curves over p]

Dots represent the input values and lines show the softmax outputs of those values for each p.

Setting p to low values results in all softmax values converging to 1 / n_values = 1 / 5 = 0.2.
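
At p = 0 all scaled inputs are equal, so the softmax is exactly uniform. A quick check, reusing values from the snippet above:

torch.softmax(0 * values, 0)  # tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])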

How do I instead create a function that gradually transforms the input values to their softmax as p → inf?

That function would look like:

[plot: the desired gradual transition from the input values to a one-hot output as p increases]

Hi Zimo!

There are lots of possibilities. (You could choose between them by
imposing further conditions on the behavior of your function.)

Here is an illustration of one possible choice:

>>> import torch
>>> torch.__version__
'1.7.1'
>>> _ = torch.manual_seed (2021)
>>> def gradual (values, p):
...     return torch.where (values == values.max(), 1.0 - (1.0 - values)**p, values**p)
...
>>> values = torch.rand (5)
>>> values
tensor([0.1304, 0.5134, 0.7426, 0.7159, 0.5705])
>>> gradual (values, 1)
tensor([0.1304, 0.5134, 0.7426, 0.7159, 0.5705])
>>> gradual (values, 2)
tensor([0.0170, 0.2636, 0.9337, 0.5125, 0.3254])
>>> gradual (values, 3)
tensor([0.0022, 0.1353, 0.9829, 0.3669, 0.1857])
>>> gradual (values, 5)
tensor([3.7761e-05, 3.5664e-02, 9.9887e-01, 1.8803e-01, 6.0419e-02])
>>> gradual (values, 10)
tensor([1.4259e-09, 1.2719e-03, 1.0000e+00, 3.5355e-02, 3.6504e-03])
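
The idea: the largest element follows 1.0 - (1.0 - v)**p, which rises to 1 as p increases, while every other element follows v**p, which decays to 0 for v in [0, 1). A compact, self-contained check of the limiting behavior (using a fixed tensor rather than the random values above):

import torch

def gradual(values, p):
    return torch.where(values == values.max(), 1.0 - (1.0 - values)**p, values**p)

v = torch.tensor([0.10, 0.50, 0.90])
print(gradual(v, 1))   # p = 1 leaves the values unchanged
print(gradual(v, 50))  # the max is pushed to 1, the rest decay toward 0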

Best.

K. Frank

Thanks for the answer, K. Frank.

I will try to tinker with this solution in a few hours.

Do you happen to know if torch.where breaks the gradient tree? I am looking for a differentiable function, even if I know that for large values of p the gradient will be vanishingly small.

Hi Zimo!

torch.where() does not break the computation graph – that’s a good
thing about it.

But, furthermore, in this specific case, each individual element of values
only participates in a single branch of the torch.where() function.
I’m just using it as a more compact way to single out the largest element
in values, rather than indexing with argmax().
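
As a quick check (a minimal sketch with a fresh tensor): the output of torch.where() carries a grad_fn, and backward() delivers the gradient of whichever branch each element actually took:

import torch

x = torch.tensor([0.2, 0.8], requires_grad=True)
y = torch.where(x == x.max(), 1.0 - (1.0 - x)**3, x**3)
print(y.grad_fn is not None)  # True: torch.where() is recorded in the graph
y.sum().backward()
print(x.grad)                 # each element gets the gradient of its own branch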

Best.

K. Frank

Yes, it seems to work quite nicely 🙂

values = torch.tensor([0.20, 0.40, 0.60, 0.80], requires_grad=True)

def gradual(values, p):
    return torch.where(values == values.max(), 1.0 - (1.0 - values)**p, values**p)

gradual_values = torch.stack([gradual(values, p) for p in np.arange(1, 20)])

for input, outputs in zip(values, gradual_values.T):
    c = plt.gca()._get_lines.get_next_color()
    plt.plot([0], [input.detach()], 'o', c=c)  # detach: values requires grad
    plt.plot(outputs.detach(), c=c)

plt.xlabel("p")
plt.show()

loss = gradual_values.mean()
loss.backward()
values.grad

[plot: dots at the input values with their gradual(values, p) curves over p]

tensor([0.0206, 0.0365, 0.0822, 0.0206])

Do you know if that is the case generally, that torch.where(x == x.max(), t1, t2) is differentiable while torch.argmax(x) is not? I guess they have different use cases, but this feels like a very nice “trick” to maintain the gradient.
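
A quick way to see the difference (a minimal sketch): torch.argmax() returns integer indices, which carry no grad_fn at all, whereas torch.where() selects values, so its output stays on the computation graph:

import torch

x = torch.tensor([0.2, 0.8], requires_grad=True)

idx = torch.argmax(x)                                    # LongTensor: no grad_fn, not differentiable
sel = torch.where(x == x.max(), x, torch.zeros_like(x))  # value selection: stays on the graph
print(idx.requires_grad, sel.requires_grad)              # False True

(Indexing with the result, x[idx], does keep a gradient to the selected element; it is argmax() itself that has no gradient.)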