How to implement exactly the same softmax as PyTorch's F.softmax

Hi everyone,

Recently I needed to re-implement the softmax function so that I can design my own variant of it. I referred to code on GitHub and implemented the version shown below.

import torch

def own_softmax(x):
    # Subtract the row-wise max before exp() for numerical stability
    maxes = torch.max(x, 1, keepdim=True)[0]
    x_exp = torch.exp(x - maxes)
    x_exp_sum = torch.sum(x_exp, 1, keepdim=True)

    return x_exp / x_exp_sum
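Why subtracting the row maximum is allowed at all: this is the standard shift-invariance property of softmax (not something specific to PyTorch). For any constant $c$ chosen per row,

```latex
\operatorname{softmax}(x)_i
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_j e^{x_j}}
  = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
```

Choosing $c = \max_j x_j$ makes every exponent $\le 0$, so no term can overflow, and the largest term is exactly $e^0 = 1$, which keeps the denominator $\ge 1$.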

However, after plugging it in, I found that the results are not as good as with the built-in F.softmax. So I would like to ask: what is the difference between my implementation and the built-in function?

Thank you so much!

Can anyone help here?

I have also tried another implementation, shown below:

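import torch

def own_softmax(x):
    # Subtract the row-wise mean instead of the max; torch.mean returns
    # a plain tensor (no indices), so no [0] is needed here
    means = torch.mean(x, 1, keepdim=True)
    x_exp = torch.exp(x - means)
    x_exp_sum = torch.sum(x_exp, 1, keepdim=True)

    return x_exp / x_exp_sum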

and found that this implementation achieves better accuracy. However, it is still not as good as F.softmax. Can anyone help?
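Because softmax is invariant to a per-row shift, centering by the mean or by the max should produce the same probabilities, up to float32 rounding, for ordinary logit ranges. A minimal sketch to check this on random inputs; `softmax_mean_centered` and `softmax_max_centered` are hypothetical helper names used only for illustration:

```python
import torch
import torch.nn.functional as F

def softmax_mean_centered(x):
    # Hypothetical helper: subtract the per-row mean before exp()
    means = torch.mean(x, 1, keepdim=True)
    x_exp = torch.exp(x - means)
    return x_exp / torch.sum(x_exp, 1, keepdim=True)

def softmax_max_centered(x):
    # Hypothetical helper: subtract the per-row max before exp()
    maxes = torch.max(x, 1, keepdim=True)[0]
    x_exp = torch.exp(x - maxes)
    return x_exp / torch.sum(x_exp, 1, keepdim=True)

torch.manual_seed(0)
x = torch.randn(5, 10)  # moderate logits, similar in scale to typical network outputs
ref = F.softmax(x, 1)

print(torch.allclose(softmax_mean_centered(x), ref))  # True
print(torch.allclose(softmax_max_centered(x), ref))   # True
```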

Can anyone help?

So sad, so sad…

Your custom function returns the same output as F.softmax:

import torch
import torch.nn.functional as F

x = torch.randn(5, 10)
output = F.softmax(x, 1)

maxes = torch.max(x, 1, keepdim=True)[0]
x_exp = torch.exp(x-maxes)
x_exp_sum = torch.sum(x_exp, 1, keepdim=True)
output_custom = x_exp/x_exp_sum

print(torch.allclose(output, output_custom))
# > True
print(torch.sum(torch.abs(output-output_custom)))
# > tensor(2.3108e-7)
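Since the concern is training accuracy, it may also be worth comparing the backward pass, not just the forward output. A minimal sketch, assuming the max-centered version from the question written as a standalone `own_softmax(x)` function (no `self`) and an arbitrary random upstream gradient standing in for a loss:

```python
import torch
import torch.nn.functional as F

def own_softmax(x):
    # Max-centered softmax over dim 1, as in the question
    maxes = torch.max(x, 1, keepdim=True)[0]
    x_exp = torch.exp(x - maxes)
    return x_exp / torch.sum(x_exp, 1, keepdim=True)

torch.manual_seed(0)
x_ref = torch.randn(5, 10, requires_grad=True)
x_own = x_ref.detach().clone().requires_grad_(True)
upstream = torch.randn(5, 10)  # arbitrary gradient flowing back from a loss

# Backpropagate the same upstream gradient through both implementations
F.softmax(x_ref, 1).backward(upstream)
own_softmax(x_own).backward(upstream)

print(torch.allclose(x_ref.grad, x_own.grad, atol=1e-6))  # True
```

The gradient that flows through `maxes` cancels out analytically (shifting a row does not change its softmax), so only rounding-level differences remain.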

The output of your own_softmax differs slightly from torch.nn.functional.softmax.
This may be why your own_softmax degrades the performance.

import torch

# own_softmax as defined in the question, written as a plain function
x = torch.randn(2,10)
h_own = own_softmax(x)
h = torch.nn.functional.softmax(x, 1)
print(h - h_own)
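To judge how large "slightly different" actually is, one option is to compare the largest absolute difference with float32 machine epsilon (`torch.finfo(torch.float32).eps`, about 1.19e-07). A minimal sketch, again assuming the max-centered `own_softmax` written as a standalone function; the `10 * eps` threshold is an arbitrary illustrative choice:

```python
import torch
import torch.nn.functional as F

def own_softmax(x):
    # Max-centered softmax over dim 1, as in the question
    maxes = torch.max(x, 1, keepdim=True)[0]
    x_exp = torch.exp(x - maxes)
    return x_exp / torch.sum(x_exp, 1, keepdim=True)

torch.manual_seed(0)
x = torch.randn(2, 10)
max_diff = (F.softmax(x, 1) - own_softmax(x)).abs().max().item()
eps = torch.finfo(torch.float32).eps

# A difference of a few eps is ordinary float32 rounding, not a different function
print(max_diff, eps)
print(max_diff <= 10 * eps)
```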

Thank you so much. It seems that the method used to center the network's output scores influences the softmax output a lot. I have tested centering with the max value and with the mean value, and their outputs are quite different. I am wondering whether the mean version is more stable.
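On the stability question: subtracting the mean does not bound the exponents, whereas subtracting the max guarantees every exponent is at most 0. A minimal sketch with deliberately extreme, made-up logits that shows the difference in float32; `centered_softmax` is a hypothetical helper used only for this comparison:

```python
import torch
import torch.nn.functional as F

def centered_softmax(x, center):
    # Shift each row by `center` (shape [N, 1]) before exponentiating
    x_exp = torch.exp(x - center)
    return x_exp / torch.sum(x_exp, 1, keepdim=True)

x = torch.tensor([[0.0, 1000.0]])  # extreme logits, chosen to trigger overflow

by_max = centered_softmax(x, torch.max(x, 1, keepdim=True)[0])
by_mean = centered_softmax(x, torch.mean(x, 1, keepdim=True))

print(by_max)           # tensor([[0., 1.]])
print(by_mean)          # tensor([[0., nan]]): exp(500) overflows to inf, inf / inf = nan
print(F.softmax(x, 1))  # tensor([[0., 1.]])
```

For ordinary logit ranges both shifts agree up to rounding; the max shift only starts to matter once the spread within a row is large enough for `exp` to overflow.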


Thank you so much for your reply.

If I swap torch.max for torch.mean, I get approximately the same accuracy.
Could you post the code you’ve used to check the behavior?