Is clamping the input to torch.exp a good alternative to softmax?

torch.softmax is numerically stable, even on inputs with large values. For example,

import torch
import torch.nn.functional as F

x = torch.tensor([1., 2., 150.])

F.softmax(x, dim=0)
# tensor([0., 0., 1.])

However, I have to calculate the softmax manually, because in my setting I cannot use the softmax function directly.
I am doing:

torch.exp(x) / torch.exp(x).sum()
# tensor([0., 0., nan])
# exp(150.) overflows float32 to inf, and inf / inf gives nan

This nan then propagates into the loss.
So I chose to clamp the input before applying the exponential:

def clamp_exp(x):
    # Cap every input at 20 so that exp() cannot overflow
    return torch.exp(torch.clamp(x, max=20))

clamp_exp(x) / clamp_exp(x).sum()
# tensor([5.6028e-09, 1.5230e-08, 1.0000e+00])

Since I am not very familiar with the clamp function, I am wondering: is there a good way to avoid this overflow, and will the clamp affect backpropagation?
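On the backpropagation question, a quick check is to look at the gradient torch.clamp produces. In the sketch below (same input as above; `x2` is just an illustrative name), the gradient passes through unchanged for elements below the cap and is zero for elements where the clamp is active, so the clamped score 150 receives no gradient at all:

```python
import torch

# Same input as above, now tracking gradients.
x = torch.tensor([1., 2., 150.], requires_grad=True)

# d/dx clamp(x, max=20) is 1 where x < 20 and 0 where the clamp is active.
torch.clamp(x, max=20).sum().backward()
print(x.grad)  # tensor([1., 1., 0.])

# The same holds inside the clamped "softmax": no gradient reaches x2[2].
x2 = torch.tensor([1., 2., 150.], requires_grad=True)
e = torch.exp(torch.clamp(x2, max=20))
p = e / e.sum()
p[0].backward()
print(x2.grad)  # x2.grad[2] is zero because x2[2] was clamped
```

In other words, clamping not only changes the forward values (compare with F.softmax above) but also stops the gradient for any score above the cap.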

Hi Came!

First, you should modify the structure of your overall computation so that
you can use log_softmax() for this step. Whether you write your own
version or use pytorch’s, log_softmax() will almost certainly lead to a
numerically more stable computation as it is much less likely to “explode”
or underflow to zero.
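A minimal sketch of computing log_softmax by hand for a 1-D input: log_softmax(x)_i = x_i - log(sum_j exp(x_j)), where the second term comes from torch.logsumexp, which is designed to be numerically stable, so nothing overflows even with the value 150:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1., 2., 150.])

# log softmax(x)_i = x_i - log(sum_j exp(x_j)), computed stably
log_p = x - torch.logsumexp(x, dim=0)

print(log_p)                     # finite values, roughly [-149., -148., 0.]
print(F.log_softmax(x, dim=0))   # PyTorch's built-in agrees
```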

Also, you should figure out how to structure your computation to use
pytorch’s version (whether it be log_softmax() or softmax()), rather
than reinvent the wheel by writing your own. Why do you need to write
your own version? If worst were to come to worst, couldn’t you just
implement “your own version” by writing a wrapper around pytorch’s?

Best.

K. Frank


Hi Frank,
I appreciate your kind reply.
I am trying to implement a form of neighborhood attention (as in Graph Attention Networks), where softmax cannot be applied directly. Instead, you have to use scatter_add (segment_sum in TensorFlow) and calculate the denominator manually.
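A minimal sketch of the scatter_add approach described above, assuming the attention scores are a flat 1-D tensor with one entry per edge and `index[k]` gives the neighborhood (e.g. target node) of edge k. `segment_softmax_naive` and its arguments are hypothetical names. Because it exponentiates the raw scores, it overflows in exactly the same way as torch.exp(x) / torch.exp(x).sum():

```python
import torch

def segment_softmax_naive(scores, index, num_segments):
    """Softmax of `scores` within each segment given by `index` (no stabilization)."""
    numerator = torch.exp(scores)
    # denominator[s] = sum of numerator[k] over all edges k with index[k] == s
    denominator = torch.zeros(num_segments, dtype=scores.dtype).scatter_add(0, index, numerator)
    # Send each segment's sum back to its edges and normalize
    return numerator / denominator[index]

scores = torch.tensor([0.5, 1.0, 2.0, -1.0, 0.0])
index = torch.tensor([0, 0, 1, 1, 1])  # edges 0-1 in segment 0, edges 2-4 in segment 1
p = segment_softmax_naive(scores, index, num_segments=2)
print(p)  # each segment sums to 1
```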

I found that others subtract the maximum value before the exponential to keep it stable:
https://github.com/gordicaleksa/pytorch-GAT/blob/39c8f0ee634477033e8b1a6e9a6da3c7ed71bbd1/models/definitions/GAT.py#L262
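The max-subtraction trick can be sketched on the original tensor. It relies on softmax being unchanged when the same constant is subtracted from every input, so unlike clamping it gives the same result as F.softmax:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1., 2., 150.])

# exp(x_i - c) / sum_j exp(x_j - c) == exp(x_i) / sum_j exp(x_j) for any constant c.
# With c = max(x) every exponent is <= 0, so exp() cannot overflow.
shifted = torch.exp(x - x.max())
p = shifted / shifted.sum()

print(p)  # tensor([0., 0., 1.]), matching F.softmax(x, dim=0)
```

In the scatter_add setting, subtracting a single global maximum from all edge scores keeps the result exact, but a neighborhood whose scores are all far below the global maximum can underflow to a zero denominator (0 / 0 gives nan). Shifting each neighborhood by its own maximum avoids that case.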