def softmax(input, axis=1):
input_size = input.size()
trans_input = input.transpose(axis, len(input_size)-1)
trans_size = trans_input.size()
input_2d = trans_input.contiguous().view(-1, trans_size[-1])
soft_max_2d = F.softmax(input_2d)
soft_max_nd = soft_max_2d.view(*trans_size)
return soft_max_nd.transpose(axis, len(input_size)-1)
13 Likes