# How to implement exactly the same softmax as F.softmax in PyTorch

Hi everyone,

Recently I needed to re-implement the softmax function to design my own version. I referred to some code on GitHub and implemented the function shown below.

```python
def own_softmax(self, x):
    # torch.max along a dim returns a (values, indices) tuple; keep the values
    maxes = torch.max(x, 1, keepdim=True)[0]
    x_exp = torch.exp(x - maxes)
    x_exp_sum = torch.sum(x_exp, 1, keepdim=True)

    return x_exp / x_exp_sum
```

However, after implementing it I found that the results are not as good as those of the built-in `F.softmax`. So I am here to ask: what is the difference between my implementation and the built-in function?

Thank you so much!

Can anyone help here?

I have tried some other implementations, like the following one,

```python
def own_softmax(self, x):
    # center by the per-row mean instead of the per-row max
    means = torch.mean(x, 1, keepdim=True)
    x_exp = torch.exp(x - means)
    x_exp_sum = torch.sum(x_exp, 1, keepdim=True)

    return x_exp / x_exp_sum
```

and found that this implementation achieves better accuracy. However, it is still not as good as `F.softmax`. Can anyone help?
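For what it's worth, softmax is invariant to subtracting any per-row constant before exponentiating (the factor `exp(-c)` cancels between numerator and denominator), so max-centering and mean-centering should give mathematically identical probabilities. A quick sketch to check this numerically:

```python
import torch
import torch.nn.functional as F

def softmax_shifted(x, shift):
    # exp(x - c) / sum(exp(x - c)) == exp(x) / sum(exp(x)) for any constant c
    x_exp = torch.exp(x - shift)
    return x_exp / torch.sum(x_exp, 1, keepdim=True)

x = torch.randn(4, 10)
out_max = softmax_shifted(x, torch.max(x, 1, keepdim=True)[0])
out_mean = softmax_shifted(x, torch.mean(x, 1, keepdim=True))

print(torch.allclose(out_max, F.softmax(x, 1)))   # True
print(torch.allclose(out_mean, F.softmax(x, 1)))  # True
```

So any accuracy difference between the two variants must come from something other than the centering constant itself, e.g. floating-point behavior on extreme logits.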

so sad, so sad …

Your custom function returns the same output as `F.softmax`:

```python
x = torch.randn(5, 10)
output = F.softmax(x, 1)

maxes = torch.max(x, 1, keepdim=True)[0]
x_exp = torch.exp(x - maxes)
x_exp_sum = torch.sum(x_exp, 1, keepdim=True)
output_custom = x_exp / x_exp_sum

print(torch.allclose(output, output_custom))
# > True
print(torch.sum(torch.abs(output - output_custom)))
# > tensor(2.3108e-7)
```

The output from your `own_softmax` is slightly different from `torch.nn.functional.softmax`.
This may be the reason why your `own_softmax` degrades the performance.

```python
x = torch.randn(2, 10)
h_own = own_softmax(x)
h = torch.nn.functional.softmax(x, 1)
print(h - h_own)
```

Thank you so much. It seems that the centralization method applied to the network output scores influences the softmax output a lot. I have tested centralization using the max value and the mean value, and their outputs are quite different. I am wondering whether the mean one is more stable?
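On the stability question: one way to see the difference is with extreme logits (hypothetical values below). Subtracting the row max keeps every exponent at or below zero, so `exp()` never overflows; subtracting the mean can still leave large positive exponents that overflow to `inf`. A minimal sketch:

```python
import torch

x = torch.tensor([[0.0, 10000.0]])  # hypothetical extreme logits

# max-centering: x - max = [-10000, 0], exponents <= 0, no overflow
maxes = torch.max(x, 1, keepdim=True)[0]
out_max = torch.exp(x - maxes) / torch.sum(torch.exp(x - maxes), 1, keepdim=True)

# mean-centering: x - mean = [-5000, 5000]; exp(5000) overflows to inf,
# and inf / inf produces nan
means = torch.mean(x, 1, keepdim=True)
out_mean = torch.exp(x - means) / torch.sum(torch.exp(x - means), 1, keepdim=True)

print(out_max)   # tensor([[0., 1.]])
print(out_mean)  # contains nan
```

This is why the standard log-sum-exp trick (and `F.softmax` internally) centers by the max rather than the mean.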


Thank you so much for your reply.

If I swap `torch.max` for `torch.mean`, I get approx. the same accuracy.
Could you post the code you’ve used to check the behavior?