Reshape to several dimensions using an index matrix

In a project, I have a tensor of shape [N, M] that I would like to reshape to [I, J, M] while reordering its rows according to another tensor of indices. I am not sure how to get this result. I have looked into gather and index_select, but neither seems to fit my case. (The example below is the 2D case; in my project I would like to apply the same technique to the 3D case [N, M] => [I, J, K, M].)

import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])

Column 0 of indices gives each row's index along the first output dimension, and column 1 gives its index along the second dimension, i.e. row n of x should end up at output[indices[n, 0], indices[n, 1]].

The tensors above would give:

# hypothetical: output = some_function(x, indices), the function I am looking for
output = torch.tensor([
    [
        [7, 8, 9],
        [1, 2, 3],
    ],
    [
        [4, 5, 6],
        [10, 11, 12],
    ],
])
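Stated as a plain loop, the requested operation writes row n of x to output[indices[n, 0], indices[n, 1]]. The sketch below is a slow but unambiguous reference that can be used to check faster versions; the name `reshape_by_indices_loop` is hypothetical, and it assumes every output cell is listed exactly once in `indices` (unlisted cells would stay zero). Because it unpacks a whole row of `indices`, it accepts any number of index columns, so it also covers the [I, J, K, M] case.

```python
import torch

def reshape_by_indices_loop(x, indices):
    # Grid size along each index dimension: largest index + 1
    dims = (indices.max(dim=0).values + 1).tolist()
    out = torch.zeros(*dims, x.size(1), dtype=x.dtype)
    for n in range(x.size(0)):
        # Row n of x goes to the cell named by row n of indices
        out[tuple(indices[n].tolist())] = x[n]
    return out

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
print(reshape_by_indices_loop(x, indices))
```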

Thank you!

Edit:
I found a solution …
But it is quite ugly, and I would like to know if someone can think of a better way to handle it:

# https://stackoverflow.com/questions/66832716/how-to-quickly-inverse-a-permutation-by-using-pytorch
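def inverse_permutation(perm):
    # inv[perm[k]] = k, obtained by sorting (O(n log n))
    return torch.argsort(perm)

def faster_inverse_permutation(perm):
    # Same result with a single scatter (O(n)); perm must be a permutation of 0..n-1
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.size(0), device=perm.device)
    return inv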

# Row-major strides for the 2 x 2 grid: flat index = 2*i + j
values = 2**torch.arange(start=1, end=-1, step=-1)
print(values)
# Flat target position of each row of x
flattened_indices = torch.sum(indices*values, axis=1)
print(flattened_indices)
flattened_indices_inverted = faster_inverse_permutation(flattened_indices)
print(flattened_indices_inverted)
# Row k of x_permuted is the row of x whose target position is k
x_permuted = x[flattened_indices_inverted]
print(x_permuted)
x_permuted_reshaped = x_permuted.reshape(*(torch.max(indices, axis=0)[0] + 1).tolist(), -1)
print(x_permuted_reshaped)

Output:
tensor([2, 1])
tensor([1, 2, 0, 3])
tensor([2, 0, 1, 3])
tensor([[ 7,  8,  9],
        [ 1,  2,  3],
        [ 4,  5,  6],
        [10, 11, 12]])
tensor([[[ 7,  8,  9],
         [ 1,  2,  3]],

        [[ 4,  5,  6],
         [10, 11, 12]]])
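Why the inverse permutation is needed: `flattened_indices[n]` is the flat (row-major) position where row n of x must land, so the goal is `x_permuted[flattened_indices[n]] = x[n]`. Gathering with the inverse permutation achieves this, and so does scattering with the permutation itself, which skips the inversion step. A minimal sketch comparing the two on this example (using `torch.argsort` for the inverse), assuming the flattened indices form a permutation of 0..N-1:

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
flattened_indices = torch.tensor([1, 2, 0, 3])  # values printed above

# Option 1: gather with the inverse permutation
inv = torch.argsort(flattened_indices)
gathered = x[inv]

# Option 2: scatter each row straight to its target position
scattered = torch.empty_like(x)
scattered[flattened_indices] = x

print(torch.equal(gathered, scattered))  # True
print(scattered.reshape(2, 2, -1))
```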
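def custom_reshaping(x, indices):
    # Grid size along each index dimension, e.g. [I, J] or [I, J, K]
    dims = torch.max(indices, dim=0)[0] + 1
    shape = (*dims.tolist(), -1)
    # Row-major strides: strides[k] = product of dims[k+1:]
    # (powers of 2 only work when every index dimension has size 2)
    strides = torch.ones_like(dims)
    strides[:-1] = torch.flip(torch.cumprod(torch.flip(dims[1:], [0]), dim=0), [0])
    flattened_indices = torch.sum(indices * strides, dim=1)
    # Assumes every grid cell appears exactly once in indices
    flattened_indices_inverted = faster_inverse_permutation(flattened_indices)
    x_permuted = x[flattened_indices_inverted]
    return x_permuted.reshape(shape)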

custom_reshaping(x, indices)
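The `values` vector built with `2**torch.arange(...)` plays the role of row-major strides: the flat position of cell (i, j) in an [I, J] grid is i*J + j, and for [I, J, K] it is i*J*K + j*K + k. Powers of two match these strides only when every index dimension has size 2 (as in the 2 x 2 example), so for the general [I, J, K, M] case the strides should be built from the grid sizes. A sketch of that computation (`row_major_strides` is a hypothetical helper name), checked against `torch.arange(24).reshape(2, 3, 4)`, whose entries are exactly their own flat positions:

```python
import torch

def row_major_strides(dims):
    # strides[k] = product of dims[k+1:], i.e. the flat offset of one step along dim k
    strides = torch.ones_like(dims)
    strides[:-1] = torch.flip(torch.cumprod(torch.flip(dims[1:], [0]), dim=0), [0])
    return strides

dims = torch.tensor([2, 3, 4])  # an [I, J, K] grid
strides = row_major_strides(dims)
print(strides)  # tensor([12,  4,  1])

# Every cell (i, j, k) of the grid, one per row
grid = torch.cartesian_prod(torch.arange(2), torch.arange(3), torch.arange(4))
flat = (grid * strides).sum(dim=1)
expected = torch.arange(24).reshape(2, 3, 4)[grid[:, 0], grid[:, 1], grid[:, 2]]
print(torch.equal(flat, expected))  # True

# Powers of two ([4, 2, 1]) send different cells to the same position here
pow2 = 2 ** torch.arange(2, -1, -1)
print(torch.equal((grid * pow2).sum(dim=1), expected))  # False
```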

Alternatively, this should work:

# Allocate the [I, J, M] output with x's dtype, then write each row of x into its (i, j) cell
out = torch.zeros(indices[:, 0].max()+1, indices[:, 1].max()+1, x.size(1), dtype=x.dtype)
out[indices[:, 0], indices[:, 1]] = x
print(out)
> tensor([[[ 7,  8,  9],
           [ 1,  2,  3]],

          [[ 4,  5,  6],
           [10, 11, 12]]])
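The indexed assignment `out[indices[:, 0], indices[:, 1]] = x` extends to any number of index columns: unpacking the columns as a tuple, `out[tuple(indices.T)] = x`, indexes one output dimension per column. A sketch (the function name `scatter_rows` is hypothetical), assuming `indices` covers every cell; cells not listed stay zero, and if two rows share a cell, PyTorch does not guarantee which one is written.

```python
import torch

def scatter_rows(x, indices):
    # Grid size along each index dimension, e.g. [I, J] or [I, J, K]
    dims = (indices.max(dim=0).values + 1).tolist()
    out = torch.zeros(*dims, x.size(1), dtype=x.dtype, device=x.device)
    # One index tensor per column: out[i_n, j_n, ...] = x[n] for every row n
    out[tuple(indices.T)] = x
    return out

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
print(scatter_rows(x, indices))

# 3D case [N, M] -> [I, J, K, M] with a shuffled 2 x 2 x 2 grid
cells = torch.cartesian_prod(torch.arange(2), torch.arange(2), torch.arange(2))
idx3 = cells[torch.randperm(8)]
x3 = torch.arange(16).reshape(8, 2)
out3 = scatter_rows(x3, idx3)
print(out3.shape)  # torch.Size([2, 2, 2, 2])
```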

Oh, this is so much cleaner!
I will time it soon to compare performance: it uses far fewer operations, but it creates a large new tensor. I'll also post a comment if I manage to write a variant that handles a variable number of index dimensions (though for my project I think I only need three).

Thanks a lot!