Apply mask softmax

Thanks a lot!
And I have a little suggestion:
A_softmax = A_exp / (torch.sum(A_exp, dim=1, keepdim=True) + epsilon)
It avoids division by zero when the sum is zero, e.g. if every entry in a row is masked out.
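To show where the epsilon matters, here is a rough sketch of a masked softmax. Only `A_exp`, `A_softmax` and `epsilon` come from the line above; the function name `masked_softmax`, the `mask` argument, the max-subtraction step and the value `1e-5` are my own assumptions, not necessarily what the original code in this thread does.

```python
import torch

def masked_softmax(A, mask, dim=1, epsilon=1e-5):
    # Hypothetical helper: A holds raw scores, mask has the same shape
    # with 1 for positions to keep and 0 for positions to ignore.
    # Subtracting the row max keeps exp() from overflowing (assumed step).
    A_max = torch.max(A, dim=dim, keepdim=True)[0]
    A_exp = torch.exp(A - A_max)
    # Zero out the masked positions before normalizing.
    A_exp = A_exp * mask.float()
    # epsilon keeps the denominator non-zero when a whole row is masked.
    A_softmax = A_exp / (torch.sum(A_exp, dim=dim, keepdim=True) + epsilon)
    return A_softmax

A = torch.tensor([[1.0, 2.0, 3.0],
                  [1.0, 2.0, 3.0]])
mask = torch.tensor([[1, 1, 0],   # last position masked
                     [0, 0, 0]])  # whole row masked
out = masked_softmax(A, mask)
print(out)
```

With the epsilon, the fully masked second row comes out as all zeros; without it, that row would be 0/0 and turn into NaN. The trade-off is that the unmasked rows sum to very slightly less than 1.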
