In a project, I have a tensor of shape [N, M] that I would like to reshape to [I, J, M] while reordering its rows according to another tensor. I am not sure how to get this result. I have looked into gather and index_select, but neither seems to fit my case. (The example below is 2D, but in my project I would like to apply the same technique to a 3D case: [N, M] => [I, J, K, M].)
```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
# 1st column of indices => index along the first output dimension
# 2nd column of indices => index along the second output dimension
# …
```
In other words, row n of x should end up at output[indices[n, 0], indices[n, 1]], so the tensors above would give:
```python
# output = torch.function(x, indices)  (placeholder for the function I am looking for)
output = [
    [
        [7, 8, 9],
        [1, 2, 3],
    ],
    [
        [4, 5, 6],
        [10, 11, 12],
    ],
]
```
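One way to express this mapping directly is advanced-index assignment: allocate the output and write row n of `x` to `output[indices[n, 0], indices[n, 1]]`. This is a minimal sketch, assuming every (i, j) position appears exactly once in `indices`; the helper name `place_rows` is hypothetical.

```python
import torch

def place_rows(x, indices):
    # Hypothetical helper: x is [N, M], indices is [N, 2] holding (i, j) targets.
    # Assumes the (i, j) pairs cover a full I x J grid with no duplicates.
    num_i = int(indices[:, 0].max()) + 1
    num_j = int(indices[:, 1].max()) + 1
    output = x.new_empty((num_i, num_j, x.shape[1]))
    # Row n of x is written to output[indices[n, 0], indices[n, 1]]
    output[indices[:, 0], indices[:, 1]] = x
    return output

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
indices = torch.tensor([[0, 1], [1, 0], [0, 0], [1, 1]])
output = place_rows(x, indices)
print(output)
```

If some positions were missing, `new_empty` would leave them uninitialized, so `new_zeros` would be the safer allocation in that case.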
Thank you!
Edit:
I found a solution …
But it is quite ugly, and I would like to know if someone can think of a better way to handle this:
```python
# https://stackoverflow.com/questions/66832716/how-to-quickly-inverse-a-permutation-by-using-pytorch
def inverse_permutation(perm):
    return torch.argsort(perm)

def faster_inverse_permutation(perm):
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.size(0), device=perm.device)
    return inv

# Place values of each index column: [2, 1] for this 2 x 2 grid
values = 2**torch.arange(start=1, end=-1, step=-1)
print(values)
# Row-major flat position of each row of x in the output
flattened_indices = torch.sum(indices * values, dim=1)
print(flattened_indices)
# Output row p comes from row flattened_indices_inverted[p] of x
flattened_indices_inverted = faster_inverse_permutation(flattened_indices)
print(flattened_indices_inverted)
x_permuted = x[flattened_indices_inverted]
print(x_permuted)
x_permuted_reshaped = x_permuted.reshape(*(torch.max(indices, dim=0)[0] + 1).tolist(), -1)
print(x_permuted_reshaped)
```
```text
tensor([2, 1])
tensor([1, 2, 0, 3])
tensor([2, 0, 1, 3])
tensor([[ 7,  8,  9],
        [ 1,  2,  3],
        [ 4,  5,  6],
        [10, 11, 12]])
tensor([[[ 7,  8,  9],
         [ 1,  2,  3]],

        [[ 4,  5,  6],
         [10, 11, 12]]])
```
```python
def custom_reshaping(x, indices):
    # Size of the output grid along each index dimension, e.g. [I, J] or [I, J, K]
    grid_shape = (indices.max(dim=0).values + 1).tolist()
    # Row-major strides of that grid: [J, 1] in 2D, [J*K, K, 1] in 3D
    strides = torch.tensor(grid_shape[1:] + [1], device=indices.device).flip(0).cumprod(0).flip(0)
    # Flat position of each row of x in the output
    flattened_indices = torch.sum(indices * strides, dim=1)
    flattened_indices_inverted = faster_inverse_permutation(flattened_indices)
    x_permuted = x[flattened_indices_inverted]
    return x_permuted.reshape(*grid_shape, -1)

custom_reshaping(x, indices)
```
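For the 3D case ([N, M] => [I, J, K, M]), the same index assignment generalizes by splitting `indices` into one index tensor per column, with no flattening or permutation inversion. This is a sketch assuming each grid position appears exactly once; `place_rows_nd` is a hypothetical name, and the 2 x 3 x 2 test grid below is made up for illustration.

```python
import torch

def place_rows_nd(x, indices):
    # Hypothetical helper. x: [N, M]; indices: [N, D], row n = grid position of x[n].
    # Assumes the positions cover a full grid with no duplicates.
    grid_shape = (indices.max(dim=0).values + 1).tolist()
    output = x.new_empty((*grid_shape, x.shape[1]))
    # tuple(indices.unbind(dim=1)) is (indices[:, 0], ..., indices[:, D - 1])
    output[tuple(indices.unbind(dim=1))] = x
    return output

# 3D example: all positions of a 2 x 3 x 2 grid, listed in reverse row-major order
grid = torch.cartesian_prod(torch.arange(2), torch.arange(3), torch.arange(2))
indices = grid.flip(0)
x = torch.arange(12 * 4).reshape(12, 4)

output = place_rows_nd(x, indices)
print(output.shape)  # torch.Size([2, 3, 2, 4])
```

Because the positions are listed in reverse order, flattening the result back to [12, 4] should give the rows of `x` reversed.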