import torch
import torch.nn.functional as F
def softmax(input, axis=1):
"""
Apply softmax on input at certain axis.
Parammeters:
----------
input: Tensor (N*L or rank>2)
axis: the axis to apply softmax
Returns: Tensor with softmax applied on that dimension.
"""
input_size = input.size()
trans_input = input.transpose(axis, len(input_size)-1)
trans_size = trans_input.size()
input_2d = trans_input.view(-1, trans_size[-1])
soft_max_2d = F.softmax(input_2d)
soft_max_nd = soft_max_2d.view(*trans_size)
return soft_max_nd.transpose(axis, len(input_size)-1)
aa= torch.randn(3,4,4)
print aa
soft_1 = softmax(aa, axis = 1)
print soft_1
gives the following error:
File "/local/anaconda2/lib/python2.7/site-packages/torch/tensor.py", line 214, in view
raise ValueError("input should be contiguous")
ValueError: input should be contiguous