I am trying to permute the feature map dimension of a tensor. As a very simplified case,
suppose I have a tensor of size (5, 4, 3, 6).
I want to rearrange this tensor along dimension 1 (the one of size 4), reordering it from 0,1,2,3 to 0,2,1,3.
One possible way I found was to do an index_select followed by cat. But for a larger tensor, a lot of intermediate tensors would have to be created.
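When the same permutation applies to every sample, the cat step should not be needed. index_select accepts a whole index tensor and gathers all the requested slices along one dimension in a single call, producing one new tensor rather than one per slice. This is a minimal sketch, assuming a recent PyTorch version:

```python
import torch

# Same toy shape as in the question: (batch, feature maps, H, W)
a = torch.arange(5 * 4 * 3 * 6).reshape(5, 4, 3, 6)

# Desired feature-map order 0,2,1,3
perm = torch.tensor([0, 2, 1, 3])

# One call gathers all feature maps in the new order along dim 1
b = a.index_select(1, perm)

print(b.shape)  # torch.Size([5, 4, 3, 6])
```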
>>> n = np.arange(480).reshape((5,4,4,6))
>>> a = torch.from_numpy(n)
>>> perm = torch.LongTensor([0,2,1,3])
>>> a[:, perm, :, :]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: indexing a tensor with an object of type LongTensor. The only supported types are integers, slices, numpy scalars and torch.LongTensor or torch.ByteTensor as the only argument.
Isn’t perm already a torch.LongTensor? Am I missing something here?
It will work, but you need to remove the trailing :, : and write a[:, perm]. The trailing slices don’t do anything, and for the moment we require LongTensor indices to appear as the last elements in the index.
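That restriction belongs to older releases. In current PyTorch versions an index tensor can appear in any position, so the short and the full form of the original expression select the same feature maps. This is a minimal check, assuming a recent PyTorch release:

```python
import numpy as np
import torch

n = np.arange(480).reshape((5, 4, 4, 6))
a = torch.from_numpy(n)
perm = torch.tensor([0, 2, 1, 3])

short = a[:, perm]          # trailing dimensions are implied
full = a[:, perm, :, :]     # explicit trailing slices, as in the question

# Both index forms give the same result
print(torch.equal(short, full))  # True
```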
Just an FYI: this is now working. However, I would be interested in a case where the permutation is different for each row (first axis)… is there a neat advanced indexing trick that would do it?
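One way to apply a different permutation to each row is to pair the per-row indices with a broadcast row index, so that advanced indexing picks x[i, perms[i, j]]. torch.gather expresses the same thing along a single dimension. For the 4-D case in the question, the index has to be expanded to the full shape before calling gather on dim 1. This is a sketch, assuming one permutation per row is stored in a tensor; the names perms and chan_perms are illustrative:

```python
import torch

x = torch.arange(12).reshape(3, 4)

# Illustrative per-row permutations: row i is reordered by perms[i]
perms = torch.stack([torch.randperm(4) for _ in range(3)])

# Row indices of shape (3, 1) broadcast against perms of shape (3, 4)
rows = torch.arange(x.size(0)).unsqueeze(1)
y = x[rows, perms]  # y[i, j] == x[i, perms[i, j]]

# Equivalent formulation with gather along dim 1
y_gather = torch.gather(x, 1, perms)
print(torch.equal(y, y_gather))  # True

# 4-D case: a different feature-map order for each sample
a = torch.arange(5 * 4 * 3 * 6).reshape(5, 4, 3, 6)
chan_perms = torch.stack([torch.randperm(4) for _ in range(5)])  # (5, 4)
idx = chan_perms[:, :, None, None].expand(-1, -1, 3, 6)          # (5, 4, 3, 6)
b = torch.gather(a, 1, idx)  # b[i, j] == a[i, chan_perms[i, j]]
```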
This answer is a bit late, but I stumbled upon this thread while trying to create a permutation matrix from permutation indices, so I’ll share what I found works.
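One common way to build a permutation matrix from an index vector is to index the rows of an identity matrix. This is offered as a sketch, not necessarily what the poster found. Multiplying by the resulting matrix reorders a vector exactly as indexing with the same indices would:

```python
import torch

perm = torch.tensor([0, 2, 1, 3])
n = perm.numel()

# Row i of P is the unit vector e_{perm[i]}, so (P @ v)[i] == v[perm[i]]
P = torch.eye(n)[perm]

v = torch.tensor([10.0, 20.0, 30.0, 40.0])
print(P @ v)  # tensor([10., 30., 20., 40.])
```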
Regarding how to permute vectors in PyTorch: there seems to be a function torch.permute(), but I can’t find any documentation for it, and when I try it, it doesn’t work as I would expect (it seems to be a no-op).
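Tensor.permute reorders the dimensions of a tensor, not the elements along one dimension, and its arguments are the new order of the dimensions. That explains the apparent no-op: for a 2-D tensor, the order (0, 1) is the identity. This is a minimal sketch of what it does:

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)

# New dim order: old dim 2 first, then old dim 0, then old dim 1
p = t.permute(2, 0, 1)
print(p.shape)  # torch.Size([4, 2, 3])

# Element p[k, i, j] is t[i, j, k]
print(torch.equal(p[3, 1, 2], t[1, 2, 3]))  # True

# For a 2-D tensor, the identity order (0, 1) changes nothing
x = torch.randn(3, 3)
print(torch.equal(x.permute(0, 1), x))  # True
```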
Oh, I see. I guess I must have found the wrong version of the documentation when I searched online. I was passing a LongTensor as a second argument, expecting it to be used for indexing, but I now see it was ignored.
>>> x = torch.randn(3,3)
>>> x
tensor([[-1.4109, -0.0597, -0.8855],
        [-0.6355, -0.6556, -1.9610],
        [ 1.0115,  1.5676,  0.8374]])
>>> x.permute((0,1), torch.LongTensor([2, 1, 0]))
This gives the same result as x.
I see now that what I was trying to do is supposed to be done with indexing operators.
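For the 3x3 example above, reordering the columns (or the rows) into the order 2, 1, 0 is done by indexing with an index tensor. For a plain reversal like this one, torch.flip gives the same result. This is a minimal sketch:

```python
import torch

x = torch.randn(3, 3)
idx = torch.tensor([2, 1, 0])

cols_reordered = x[:, idx]  # columns in order 2, 1, 0
rows_reordered = x[idx]     # rows in order 2, 1, 0

# A full reversal can also be written with torch.flip
print(torch.equal(cols_reordered, torch.flip(x, dims=[1])))  # True
print(torch.equal(rows_reordered, torch.flip(x, dims=[0])))  # True
```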