Why can't the softmax function specify the dimension to operate on?

It’s so inconvenient…

import torch.nn.functional as F

def softmax(input, axis=1):
    input_size = input.size()

    # Swap the target axis with the last dimension
    trans_input = input.transpose(axis, len(input_size) - 1)
    trans_size = trans_input.size()

    # Flatten to 2D so softmax can run along the last dimension
    input_2d = trans_input.contiguous().view(-1, trans_size[-1])

    soft_max_2d = F.softmax(input_2d, dim=1)

    # Restore the original shape and axis order
    soft_max_nd = soft_max_2d.view(*trans_size)
    return soft_max_nd.transpose(axis, len(input_size) - 1)
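As a quick sanity check, here is a minimal usage sketch (the tensor shape and axis choice are arbitrary examples, not from the snippet above): apply the helper along dim 0 of a 3-D tensor and confirm each slice along that axis sums to 1.

import torch

# Example: softmax over dim 0 of a 3-D tensor (shape chosen arbitrarily)
x = torch.randn(2, 3, 4)
out = softmax(x, axis=0)

# Each slice along the chosen axis should now sum to 1
print(out.sum(dim=0))  # expect a 3x4 tensor of ones (up to float error)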

Thanks for this snippet!
Just in case, make sure you do import torch.nn.functional as F first 🙂
