Hierarchical softmax in PyTorch

I’m aware of the softmax function in PyTorch. However, when using it I run into computational complexity problems because of the normalising factor in the denominator of the softmax, since my classification task has a very large number of classes. I can’t use negative sampling instead of softmax, because the parameters that appear in the normalising factor are changed during optimisation, so we just can’t ignore that normalising term.

How can I use hierarchical softmax instead, in order to approximate the full softmax? Thank you so much.
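
To make the question concrete, here is the kind of two-level factorisation I have in mind, where P(class | x) = P(cluster | x) · P(class | cluster, x). This is only a minimal sketch, assuming the classes can be partitioned into equal-sized clusters; the class and parameter names below are my own, not an existing API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoLevelHierarchicalSoftmax(nn.Module):
    """Minimal two-level hierarchical softmax.

    Classes are partitioned into n_clusters groups of cluster_size each,
    so P(class | x) = P(cluster | x) * P(class | cluster, x). The per-step
    cost drops from O(n_classes) to roughly O(n_clusters + cluster_size).
    """

    def __init__(self, in_features, n_clusters, cluster_size):
        super().__init__()
        self.cluster_size = cluster_size
        # First level: logits over the (small) set of clusters.
        self.cluster_logits = nn.Linear(in_features, n_clusters)
        # Second level: one weight matrix and bias per cluster, stored as
        # single tensors so we can gather the target's cluster by index.
        self.class_weight = nn.Parameter(
            torch.randn(n_clusters, cluster_size, in_features) * 0.01)
        self.class_bias = nn.Parameter(torch.zeros(n_clusters, cluster_size))

    def forward(self, x, target):
        # target: (batch,) flat class indices in [0, n_clusters * cluster_size)
        cluster = target // self.cluster_size  # which cluster each target lives in
        within = target % self.cluster_size    # index inside that cluster

        # -log P(cluster | x): an ordinary softmax over clusters only.
        cluster_nll = F.cross_entropy(
            self.cluster_logits(x), cluster, reduction='none')

        # -log P(class | cluster, x): softmax restricted to the target's cluster.
        w = self.class_weight[cluster]  # (batch, cluster_size, in_features)
        b = self.class_bias[cluster]    # (batch, cluster_size)
        within_logits = torch.bmm(w, x.unsqueeze(-1)).squeeze(-1) + b
        within_nll = F.cross_entropy(within_logits, within, reduction='none')

        return (cluster_nll + within_nll).mean()
```

A quick usage example with made-up sizes:

```python
model = TwoLevelHierarchicalSoftmax(in_features=128, n_clusters=100, cluster_size=100)
x = torch.randn(32, 128)                       # batch of hidden states
target = torch.randint(0, 100 * 100, (32,))    # flat class indices
loss = model(x, target)
loss.backward()
```

For what it's worth, PyTorch also ships `nn.AdaptiveLogSoftmaxWithLoss`, which implements the related adaptive softmax of Grave et al.; it might be a usable alternative if a frequency-ordered class hierarchy is acceptable.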
