Index and replace into 3D tensor according to list of indices

Hi!

I have a problem where I need to select vectors from a three dimensional tensor based on a list of 2D indices. However, these indices may select different numbers of vectors per dimension.

To better clarify this problem, take the following example:

import torch

A = torch.FloatTensor(  # 3D data tensor, shape (3, 2, 4)
    [
        [
            [0, 0, 0, 0],
            [1, 1, 1, 1]
        ],
        [
            [2, 2, 2, 2],
            [3, 3, 3, 3]
        ],
        [
            [4, 4, 4, 4],
            [5, 5, 5, 5]
        ]
    ]
)
B = [(0, 0), (0, 1), (1, 0)]  # list of (i, j) pairs indexing the first two dims of A

C = IndexSelect(A, B) # An unknown selection function
"""
C: [
    [0, 0, 0, 0],
    [1, 1, 1, 1],
    [2, 2, 2, 2]
]
"""

So C contains every vector of A at the indices listed in B. Note that B may select an arbitrary number of vectors from each slice A[i], and may also skip a slice entirely (e.g., B selects nothing from A[2]).

Is there a way to create such an “IndexSelect” function?
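A minimal sketch, assuming every pair in B indexes the first two dimensions of A: convert B to an integer tensor and use PyTorch advanced indexing with one index tensor per dimension. `index_select_pairs` is a hypothetical name for illustration, not a PyTorch function.

```python
import torch
from typing import Sequence, Tuple


def index_select_pairs(A: torch.Tensor, B: Sequence[Tuple[int, int]]) -> torch.Tensor:
    """Hypothetical helper: return A[i, j, :] for every (i, j) in B, stacked row by row."""
    idx = torch.tensor(B, dtype=torch.long)  # shape (len(B), 2)
    # Two 1-D index tensors are paired elementwise, so row k of the
    # result is A[idx[k, 0], idx[k, 1]], a vector of length A.shape[-1].
    return A[idx[:, 0], idx[:, 1]]


A = torch.tensor(
    [[[0, 0, 0, 0], [1, 1, 1, 1]],
     [[2, 2, 2, 2], [3, 3, 3, 3]],
     [[4, 4, 4, 4], [5, 5, 5, 5]]],
    dtype=torch.float32,
)
B = [(0, 0), (0, 1), (1, 0)]

C = index_select_pairs(A, B)
print(C)  # rows of 0s, 1s and 2s, shape (3, 4)
```

Because each pair picks exactly one vector, the result always has `len(B)` rows, regardless of how the pairs are spread across A[0], A[1] and A[2].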

Moreover, is it possible to do this with the ability to then replace selected vectors with others?

For instance, piggy-backing on the first example,

D = [
    [6, 6, 6, 6],
    [7, 7, 7, 7],
    [8, 8, 8, 8]
]
A = Replace(A, B, D)
"""
A: [
        [
            [6, 6, 6, 6],
            [7, 7, 7, 7]
        ],
        [
            [8, 8, 8, 8],
            [3, 3, 3, 3]
        ],
        [
            [4, 4, 4, 4],
            [5, 5, 5, 5]
        ]
    ]
"""

Here, "Replace" overwrites the vectors of A at the positions listed in B with the corresponding rows of D: the k-th pair in B receives D[k].
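The same advanced indexing also works on the left-hand side of an assignment. A sketch under the same assumptions, with `replace_pairs` as a hypothetical helper name; it clones A so the original tensor is left untouched.

```python
import torch
from typing import Sequence, Tuple


def replace_pairs(A: torch.Tensor, B: Sequence[Tuple[int, int]], D: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: copy of A where A[i, j, :] = D[k] for the k-th pair (i, j) in B.

    Assumes D has shape (len(B), A.shape[-1]).
    """
    idx = torch.tensor(B, dtype=torch.long)  # shape (len(B), 2)
    out = A.clone()  # out-of-place: do not modify the caller's A
    # Row k of D is written to out[idx[k, 0], idx[k, 1]].
    out[idx[:, 0], idx[:, 1]] = D.to(out.dtype)
    return out


A = torch.tensor(
    [[[0, 0, 0, 0], [1, 1, 1, 1]],
     [[2, 2, 2, 2], [3, 3, 3, 3]],
     [[4, 4, 4, 4], [5, 5, 5, 5]]],
    dtype=torch.float32,
)
B = [(0, 0), (0, 1), (1, 0)]
D = torch.tensor([[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], dtype=torch.float32)

A = replace_pairs(A, B, D)
print(A)
```

If B contains the same pair twice, PyTorch does not promise which row of D ends up in A, so duplicates are best avoided. `Tensor.index_put_` is the in-place counterpart of this assignment.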

I’ve looked at selection functions like torch.gather and torch.index_select, but my understanding is that they assume the same number of elements is selected along the specified dimension. In this case, however, we may select an arbitrary number of elements along each dimension.
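One way to still use `torch.index_select` is to flatten the first two dimensions, so that each pair (i, j) becomes the single row index `i * A.shape[1] + j`. A sketch, assuming A is 3-D and B indexes its first two dimensions; `Tensor.index_copy` then handles the replacement case on the same flattened view.

```python
import torch

# Same values as the example: A[i, j, :] is filled with 2 * i + j.
A = torch.arange(6, dtype=torch.float32).repeat_interleave(4).reshape(3, 2, 4)
B = [(0, 0), (0, 1), (1, 0)]
D = torch.tensor([[6.0] * 4, [7.0] * 4, [8.0] * 4])

idx = torch.tensor(B, dtype=torch.long)
# Each (i, j) pair becomes one row index into the (3 * 2, 4) flattened view.
flat_idx = idx[:, 0] * A.shape[1] + idx[:, 1]
flat = A.reshape(-1, A.shape[-1])

# Selection: a plain index_select along dim 0 of the flattened tensor.
C = flat.index_select(0, flat_idx)

# Replacement: index_copy returns a new tensor whose rows flat_idx come from D.
A_new = flat.index_copy(0, flat_idx, D).reshape(A.shape)
```

After flattening, the selection is a single dimension with one index per requested vector, which is exactly the case `index_select` is designed for.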

Any help would be greatly appreciated! Thanks for your time!

I don’t fully understand how the posted examples select “different numbers of vectors per dimension”, as it seems that B is used to index entire “rows” of A. Could you explain where the variable length is used in your example?
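To make this concrete: every pair in B selects exactly one whole row A[i, j], so the output always has `len(B)` rows. The only thing that varies is how many rows come from each slice A[i], which can be counted with `torch.bincount` (a sketch using the B from the question, where A.shape[0] == 3).

```python
import torch

B = [(0, 0), (0, 1), (1, 0)]
idx = torch.tensor(B, dtype=torch.long)

# Number of rows B picks from each slice A[0], A[1], A[2].
per_slice = torch.bincount(idx[:, 0], minlength=3)
print(per_slice)  # tensor([2, 1, 0])
```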