Applying a sort index to a tensor with more dimensions

I need to sort a 2D tensor and apply the same sorting to a 3D tensor.
I tried direct indexing, but that does not work with multidimensional indices. I also tried gather, but it fails because the index and the source do not have the same dimensionality.

Any ideas?

    import torch
    import numpy.testing as npt

    t = torch.tensor([[1,3,2],
                      [4,2,3]])
    t2 = torch.tensor([[[1.1,1.2],[3.1,3.2],[2.1,2.2]],
                       [[4.1,4.2],[2.1,2.2],[3.1,3.2]]])

    t2sorted = torch.tensor([[[1.1,1.2],[2.1,2.2],[3.1,3.2]],
                             [[2.1,2.2],[3.1,3.2],[4.1,4.2]]])

    s,idx = t.sort(-1)

    # direct indexing
    npt.assert_equal(t2sorted,t2[idx].numpy())
    # RuntimeError: invalid argument 2: out of range: 12 out of 12 at c:\programdata\miniconda3\conda-bld\pytorch-cpu_1524541161962\work\aten\src\th\generic/THTensorMath.c:430

    # gather
    npt.assert_equal(t2sorted,t2.gather(1,idx).numpy())
    # RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at c:\programdata\miniconda3\conda-bld\pytorch-cpu_1524541161962\work\aten\src\th\generic/THTensorMath.c:581


    # if t2 has the same shape as t, gather works
    t2 = torch.tensor([[2,4,3],
                       [5,3,4]])
    s,idx = t.sort(-1)
    npt.assert_equal(s,(t2.gather(1,idx)-1).numpy())

You need to sort each of the nested arrays first, and then sort by the first slice. Assuming the t2 sub-arrays are already sorted:

    sort_by = np.argsort(t2[:, 0])
    t2sorted = t2[sort_by]

Thanks for looking into this!
Your example is not quite what I need. My sort order for the first two dimensions is already defined by the sort order of t, stored in idx. I need to apply idx to t2 with the goal of getting t2sorted; t2sorted in the example is just the expected result for testing purposes.

You need to extract the slice from t2 that contains: 1.1,3.1,2.1,4.1,2.1,3.1 and then sort by these. I didn’t check my code :slight_smile:
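
A minimal sketch of that idea (also unchecked; the names order and rows are just illustrative). In this particular example the first element of each inner pair mirrors t, so sorting by it happens to reproduce t2sorted:

    # sort by the first element of each inner pair, then reorder t2 along dim 1
    order = t2[:, :, 0].argsort(dim=1)             # shape (2, 3)
    rows = torch.arange(t2.size(0)).unsqueeze(-1)  # shape (2, 1), broadcasts over the columns
    t2sorted = t2[rows, order]                     # advanced indexing, shape (2, 3, 2)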

Here is the solution I came up with. It is ugly but does not need a loop:

    nLastDim = t2.shape[-1]
    nLast2Dim = t2.shape[-2]
    nLast3Dim = t2.shape[-3]
    lastDimCounter = torch.arange(0,nLastDim,dtype=torch.long)
    last3DimCounter = torch.arange(0,nLast3Dim,dtype=torch.long)
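    # Build flat indices into t2.reshape(-1): an outer-row offset
    # (last3DimCounter*nLastDim*nLast2Dim), plus the sorted middle-dim offset
    # (idx*nLastDim), plus the inner-element offset (lastDimCounter),
    # all broadcast to t2's shape.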
    t2 = t2.reshape(-1)[(idx*nLastDim+(last3DimCounter*nLastDim*nLast2Dim).unsqueeze(-1)).unsqueeze(-1).expand(-1,-1,nLastDim) + lastDimCounter]
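
For comparison, gather from the original question also works once the 2D index is expanded to t2's dimensionality; a minimal sketch, assuming t2 and idx as originally defined in the question (i.e. before the reassignment above):

    # expand the (2, 3) sort index along a new trailing dim so it has the same number of dims as t2
    idx3 = idx.unsqueeze(-1).expand(-1, -1, t2.shape[-1])  # shape (2, 3, 2)
    # gather along dim 1: out[i, j, k] = t2[i, idx[i, j], k]
    npt.assert_equal(t2sorted, t2.gather(1, idx3).numpy())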

Shape returns all the sizes at once (a torch.Size, which behaves like a tuple), so you can do that in one call.
I.e.,

    ndim = t2.shape

where ndim[0] equals nLast3Dim in your code.
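
For example, all three sizes can be unpacked in one line (same names as in your code):

    # t2.shape is a torch.Size, which unpacks like a tuple
    nLast3Dim, nLast2Dim, nLastDim = t2.shape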

You can use the sort function on the slice as I mentioned for cleaner implementation. If you still need help I’ll do it when I get to a computer :smile: