I need to sort a 2D tensor along its last dimension and apply the same reordering to the second dimension of a 3D tensor.
I tried direct indexing, but that does not work with multidimensional indices. I also tried `gather`, but it does not work because the index and the source do not have the same dimensionality.
Any ideas?
```python
import numpy.testing as npt
import torch

t = torch.tensor([[1, 3, 2],
                  [4, 2, 3]])
t2 = torch.tensor([[[1.1, 1.2], [3.1, 3.2], [2.1, 2.2]],
                   [[4.1, 4.2], [2.1, 2.2], [3.1, 3.2]]])
# expected result: each row of t2 reordered by the sort order of the matching row of t
t2sorted = torch.tensor([[[1.1, 1.2], [2.1, 2.2], [3.1, 3.2]],
                         [[2.1, 2.2], [3.1, 3.2], [4.1, 4.2]]])
s, idx = t.sort(-1)

# direct indexing
npt.assert_equal(t2sorted, t2[idx].numpy())
# RuntimeError: invalid argument 2: out of range: 12 out of 12 at c:\programdata\miniconda3\conda-bld\pytorch-cpu_1524541161962\work\aten\src\th\generic/THTensorMath.c:430

# gather
npt.assert_equal(t2sorted, t2.gather(1, idx).numpy())
# RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at c:\programdata\miniconda3\conda-bld\pytorch-cpu_1524541161962\work\aten\src\th\generic/THTensorMath.c:581

# if t2 has the same shape as t, gather works
t2 = torch.tensor([[2, 4, 3],
                   [5, 3, 4]])
s, idx = t.sort(-1)
npt.assert_equal(s.numpy(), (t2.gather(1, idx) - 1).numpy())
```
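Since `gather` complains that the index and the source differ in dimensionality, one possible workaround is to give the index the missing trailing dimension. Below is a minimal sketch of two options, assuming the goal is to reorder `t2` along dim 1 using the indices from sorting `t` along its last dimension. The names `idx3`, `rows`, `by_gather` and `by_indexing` are illustrative only.

```python
import torch

t = torch.tensor([[1, 3, 2],
                  [4, 2, 3]])
t2 = torch.tensor([[[1.1, 1.2], [3.1, 3.2], [2.1, 2.2]],
                   [[4.1, 4.2], [2.1, 2.2], [3.1, 3.2]]])

# Sort t along its last dimension; idx has shape (2, 3).
s, idx = t.sort(-1)

# Option 1: gather along dim 1 with an index of the same dimensionality as t2.
# unsqueeze(-1) gives shape (2, 3, 1); expand repeats each row index across
# t2's last dimension, giving shape (2, 3, 2).
idx3 = idx.unsqueeze(-1).expand(-1, -1, t2.size(-1))
by_gather = t2.gather(1, idx3)

# Option 2: advanced indexing with an explicit batch index.
# rows has shape (2, 1) and broadcasts against idx (2, 3), so element [i, j]
# of the result is t2[i, idx[i, j]], a length-2 vector.
rows = torch.arange(t2.size(0)).unsqueeze(1)
by_indexing = t2[rows, idx]

print(by_gather)
print(torch.equal(by_gather, by_indexing))
```

`expand` returns a view rather than copying the index, so the first option does not allocate a full-size index tensor; the second option avoids building the 3D index at all by letting broadcasting pair each row number with its sort indices.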