Selecting elements in a Tensor with a varying number of dimensions using another Tensor of indices

I have a tensor A with some number of dimensions, filled with values, and a tensor C with the same shape as A minus the last dimension, holding indices into A's last dimension. I want to pick values from A based on C so that the result has the shape of C.

Example:

A has shape (7, 2, 4), C has shape (7, 2).

A:

    [[[-10.0491,  -4.9780,  -3.0346,  -1.0746],
     [  1.1812,   1.8627,   3.2540,   5.5354]],

    [[ -9.9464,  -4.9588,  -2.9927,  -0.8602],
     [  1.0148,   1.9898,   2.7656,   4.7714]],

    [[ -9.7778,  -5.0038,  -3.0378,  -1.0750],
     [  1.0365,   2.1866,   2.8971,   4.8669]],

    [[-10.1701,  -4.8115,  -3.0066,  -1.0485],
     [  0.7645,   1.8798,   3.0735,   4.7153]],

    [[ -9.8422,  -4.9419,  -3.1802,  -0.8486],
     [  0.8864,   1.7320,   3.3117,   4.8196]],

    [[-10.0794,  -4.9817,  -2.9976,  -0.9717],
     [  0.6711,   1.9173,   2.8586,   4.9084]],

    [[-10.0262,  -5.1335,  -2.9970,  -0.9397],
     [  1.2652,   1.9704,   3.0415,   5.8505]]]

C:

   [[0, 1],
    [0, 3],
    [0, 1],
    [0, 3],
    [1, 0],
    [2, 1],
    [2, 1]]

The wanted resulting tensor B with shape (7, 2) is:

    [[-10.0491,   1.8627],
    [ -9.9464,   4.7714],
    [ -9.7778,   2.1866],
    [-10.1701,   4.7153],
    [ -4.9419,   0.8864],
    [ -2.9976,   1.9173],
    [ -2.9970,   1.9704]]

I want this to work for an n-dimensional A and an (n-1)-dimensional C.
I have this solution for n=3:

B = A[torch.arange(C.shape[0]).unsqueeze(1), torch.arange(C.shape[1]), C]

This obviously breaks for a different number of dimensions, since one arange term is hard-coded per leading dimension.
Thank you for your help!

gather should work:

import torch

A = torch.tensor([[[-10.0491,  -4.9780,  -3.0346,  -1.0746],
                   [  1.1812,   1.8627,   3.2540,   5.5354]],
             
                  [[ -9.9464,  -4.9588,  -2.9927,  -0.8602],
                   [  1.0148,   1.9898,   2.7656,   4.7714]],
            
                  [[ -9.7778,  -5.0038,  -3.0378,  -1.0750],
                   [  1.0365,   2.1866,   2.8971,   4.8669]],
            
                  [[-10.1701,  -4.8115,  -3.0066,  -1.0485],
                   [  0.7645,   1.8798,   3.0735,   4.7153]],
            
                  [[ -9.8422,  -4.9419,  -3.1802,  -0.8486],
                   [  0.8864,   1.7320,   3.3117,   4.8196]],
            
                  [[-10.0794,  -4.9817,  -2.9976,  -0.9717],
                   [  0.6711,   1.9173,   2.8586,   4.9084]],
            
                  [[-10.0262,  -5.1335,  -2.9970,  -0.9397],
                   [  1.2652,   1.9704,   3.0415,   5.8505]]])

C = torch.tensor([[0, 1],
                  [0, 3],
                  [0, 1],
                  [0, 3],
                  [1, 0],
                  [2, 1],
                  [2, 1]])

ref = torch.tensor([[-10.0491,   1.8627],
                    [ -9.9464,   4.7714],
                    [ -9.7778,   2.1866],
                    [-10.1701,   4.7153],
                    [ -4.9419,   0.8864],
                    [ -2.9976,   1.9173],
                    [ -2.9970,   1.9704]])

res = A.gather(2, C.unsqueeze(2)).squeeze(2)

print((res == ref).all())
# tensor(True)
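
In case it helps to see why the unsqueeze/squeeze pair is needed: gather expects the index tensor to have the same number of dimensions as the input, and for dim=2 it computes out[i][j][k] = A[i][j][index[i][j][k]]. Below is a minimal sketch that checks this against explicit loops. The random data and the names idx and expected are only for illustration and are not from the post above.

```python
import torch

torch.manual_seed(0)
A = torch.randn(7, 2, 4)
C = torch.randint(0, 4, (7, 2))       # indices into the last dimension of A

# gather needs index.dim() == A.dim(), so add a trailing axis of size 1
idx = C.unsqueeze(2)                  # shape (7, 2, 1)
res = A.gather(2, idx).squeeze(2)     # shape (7, 2)

# Reference: pick A[i, j, C[i, j]] one element at a time
expected = torch.empty(7, 2)
for i in range(7):
    for j in range(2):
        expected[i, j] = A[i, j, C[i, j]]

print(torch.equal(res, expected))
```

Note that the index tensor has to be an integer (int64) tensor, which torch.randint and torch.tensor with Python ints produce by default.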

Perfect, this is what I was looking for. For any number of dimensions, not just three, index the last dimension with -1:

A.gather(-1, C.unsqueeze(-1)).squeeze(-1)
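
As a rough sanity check of the -1 variant, here is a small sketch that wraps it in a helper and compares it with a flatten-based reference for 2, 3 and 4 dimensions. The names select_last_dim and reference are hypothetical helpers made up for this example, and the shapes are arbitrary.

```python
import torch

def select_last_dim(a, c):
    """Return a tensor shaped like c, where each entry is taken from the
    last dimension of a at the position given by c (hypothetical helper)."""
    return a.gather(-1, c.unsqueeze(-1)).squeeze(-1)

def reference(a, c):
    """Same selection done by flattening all leading dimensions into rows."""
    k = a.shape[-1]
    flat = a.reshape(-1, k)                  # (number of rows, k)
    rows = torch.arange(flat.shape[0])
    return flat[rows, c.reshape(-1)].reshape(c.shape)

torch.manual_seed(0)
for shape in [(5, 4), (7, 2, 4), (3, 2, 5, 6)]:
    a = torch.randn(shape)
    c = torch.randint(0, shape[-1], shape[:-1])
    out = select_last_dim(a, c)
    print(shape, out.shape == c.shape, torch.equal(out, reference(a, c)))
```

Using -1 instead of a fixed dimension number is what removes the dependence on n, so the same line should work as long as C has the shape of A without its last dimension.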