# Reshape to several dimensions using indices matrix

In a project, I have a tensor of shape [N, M] and I would like to reshape it to [I, J, M] while reordering the rows using another tensor of indices. I am not sure how to get this result. I have looked into gather and index_select, but they don't seem to fit my case. (I wrote a 2D case below, but in my project I would like to apply this technique to a 3D case: [N, M] => [I, J, K, M].)

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
```

- indices in the 1st column => index along the first output dimension
- indices in the 2nd column => index along the second output dimension

The tensors above would give:

```python
output = torch.function(x, indices)  # the function I am looking for
output = [
    [
        [7, 8, 9],
        [1, 2, 3],
    ],
    [
        [4, 5, 6],
        [10, 11, 12],
    ],
]
```

Thank you!

Edit:
I found a solution …
But it is quite ugly, and I would like to know if someone can think of a better way to handle it:

```python
# https://stackoverflow.com/questions/66832716/how-to-quickly-inverse-a-permutation-by-using-pytorch
def faster_inverse_permutation(perm):
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.size(0), device=perm.device)
    return inv

# Weights to flatten each row of indices into a single linear index
values = 2**torch.arange(start=1, end=-1, step=-1)
print(values)
flattened_indices = torch.sum(indices * values, axis=1)
print(flattened_indices)
flattened_indices_inverted = faster_inverse_permutation(flattened_indices)
print(flattened_indices_inverted)
x_permuted = x[flattened_indices_inverted]
print(x_permuted)
x_permuted_reshaped = x_permuted.reshape(*(torch.max(indices, axis=0)[0] + 1).tolist(), -1)
print(x_permuted_reshaped)
```
```text
tensor([2, 1])
tensor([1, 2, 0, 3])
tensor([2, 0, 1, 3])
tensor([[ 7,  8,  9],
        [ 1,  2,  3],
        [ 4,  5,  6],
        [10, 11, 12]])
tensor([[[ 7,  8,  9],
         [ 1,  2,  3]],

        [[ 4,  5,  6],
         [10, 11, 12]]])
```
```python
def custom_reshaping(x, indices):
    shape = tuple((*(torch.max(indices, axis=0)[0] + 1).tolist(), -1))
    values = 2**torch.arange(start=indices.shape[1] - 1, end=-1, step=-1)
    flattened_indices = torch.sum(indices * values, axis=1)
    flattened_indices_inverted = faster_inverse_permutation(flattened_indices)
    x_permuted = x[flattened_indices_inverted]
    x_permuted_reshaped = x_permuted.reshape(shape)
    return x_permuted_reshaped

custom_reshaping(x, indices)
```

This should alternatively work:

```python
out = torch.zeros(indices[:, 0].max() + 1, indices[:, 1].max() + 1, x[0].size(0)).long()
out[indices[:, 0], indices[:, 1]] = x
print(out)
> tensor([[[ 7,  8,  9],
           [ 1,  2,  3]],

          [[ 4,  5,  6],
           [10, 11, 12]]])
```

Oh, this is so much cleaner!
I will certainly try timing it soon to compare performance: there are far fewer operations, but it does allocate a large tensor. Also, I'll post a comment if I manage to write a variation with a variable number of index dimensions (even though for my project this is probably fine, since I think I only need three dimensions).
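For what it's worth, a variable-dimension variant of the indexing approach might look like the sketch below. `reshape_by_indices` is a name I made up; it assumes `indices` covers every output cell exactly once, and it unbinds the index columns so the same code works for 2, 3, or more output dimensions:

```python
import torch

def reshape_by_indices(x, indices):
    # One output dimension per index column, each sized by max index + 1,
    # plus the trailing feature dimension of x.
    spatial_shape = (indices.max(dim=0).values + 1).tolist()
    out = torch.zeros(*spatial_shape, x.shape[1], dtype=x.dtype)
    # unbind(dim=1) yields one index tensor per output dimension, so the
    # advanced-indexing assignment generalizes to any number of columns.
    out[tuple(indices.unbind(dim=1))] = x
    return out

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
print(reshape_by_indices(x, indices))
```

With the 2D example above this reproduces the expected [2, 2, 3] output; for the 3D case, `indices` would simply have three columns.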

Thank you a lot!