I would like to expand a 2D tensor into a 3D tensor, choosing separately for each row which dimension that row is expanded along. For example, I’d like to take this 2D tensor:
values = torch.tensor([[1, 2],
                       [3, 4],
                       [5, 6],
                       [7, 8],
                       [9, 10]])
And transform it to this 3d tensor:
tensor([[[ 1,  1],
         [ 2,  2]],

        [[ 3,  4],
         [ 3,  4]],

        [[ 5,  5],
         [ 6,  6]],

        [[ 7,  7],
         [ 8,  8]],

        [[ 9, 10],
         [ 9, 10]]])
where the dimension to expand for each row is given by these indices:
torch.tensor([0, 1, 0, 0, 1])
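To make the two cases concrete: an index of 0 means each element of the row is repeated along the last dimension, and an index of 1 means the whole row is repeated along the middle dimension. A minimal sketch that builds both layouts for every row at once with broadcasting views (`along_rows` and `along_cols` are illustrative names, not part of any API; it assumes each row becomes an n × n matrix, with n = 2 here):

```python
import torch

values = torch.tensor([[1, 2],
                       [3, 4],
                       [5, 6],
                       [7, 8],
                       [9, 10]])
n = values.shape[1]

# Index 0: each element of a row is repeated along the last dimension,
# so out[i, j, k] = values[i, j], e.g. [1, 2] -> [[1, 1], [2, 2]].
along_rows = values.unsqueeze(2).expand(-1, -1, n)

# Index 1: the whole row is repeated along the middle dimension,
# so out[i, j, k] = values[i, k], e.g. [3, 4] -> [[3, 4], [3, 4]].
along_cols = values.unsqueeze(1).expand(-1, n, -1)

# expand() returns views, so no data is copied here; the two layouts are
# transposes of each other over the last two dimensions.
assert torch.equal(along_rows, along_cols.transpose(1, 2))
```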
The following code produces the desired result, but it is slow because it loops over the rows in Python:
import torch

def expand_slow(expand_indices, values):
    # Loop over rows: repeat each row twice to get a 2x2 matrix, then
    # transpose it when the row's index is 0 (transpose(1, 1) is a no-op).
    expanded_values = torch.stack([
        row_val.expand([2, -1]).transpose(expand_indices[row_idx], 1)
        for row_idx, row_val in enumerate(torch.unbind(values, dim=0))
    ], dim=0)
    return expanded_values

expand_indices = torch.tensor([0, 1, 0, 0, 1])
values = torch.tensor([[1, 2],
                       [3, 4],
                       [5, 6],
                       [7, 8],
                       [9, 10]])
expand_slow(expand_indices, values)
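A possible loop-free equivalent, sketched under the assumptions that every row should become an n × n matrix (n = 2 in this example) and that `expand_indices` only contains 0 and 1: build both expanded layouts as views and pick one per row with `torch.where`, broadcasting a per-row mask. `expand_where` is a hypothetical name, not an existing PyTorch function.

```python
import torch

def expand_where(expand_indices, values):
    """Hypothetical vectorized version; assumes indices are 0 or 1."""
    n = values.shape[1]
    # Layout for index 0: out[i, j, k] = values[i, j]
    along_rows = values.unsqueeze(2).expand(-1, -1, n)
    # Layout for index 1: out[i, j, k] = values[i, k]
    along_cols = values.unsqueeze(1).expand(-1, n, -1)
    # Per-row boolean mask shaped (N, 1, 1) so it broadcasts over the last two dims.
    use_rows = (expand_indices == 0).view(-1, 1, 1)
    return torch.where(use_rows, along_rows, along_cols)

expand_indices = torch.tensor([0, 1, 0, 0, 1])
values = torch.tensor([[1, 2],
                       [3, 4],
                       [5, 6],
                       [7, 8],
                       [9, 10]])
result = expand_where(expand_indices, values)
```

Both layouts are `expand()` views, so the only new allocation is the output of `torch.where`.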
What’s the best way to perform this operation with efficient, vectorized PyTorch functions? Thank you!
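A gather-based sketch, assuming the same convention (index 0 gives out[i, j, k] = values[i, j], index 1 gives out[i, j, k] = values[i, k]) and n × n output per row: compute, for every output position, which column of `values` to read, then fetch all elements in one `torch.gather` call. `expand_gather` is a hypothetical helper name.

```python
import torch

def expand_gather(expand_indices, values):
    """Hypothetical gather-based version; assumes indices are 0 or 1."""
    num_rows, n = values.shape
    j = torch.arange(n, device=values.device).view(1, n, 1)  # middle-dim position
    k = torch.arange(n, device=values.device).view(1, 1, n)  # last-dim position
    # Index 0 reads column j, index 1 reads column k; broadcasts to (N, n, n).
    cols = torch.where((expand_indices == 0).view(-1, 1, 1), j, k)
    # Gather along dim 1 with the flattened column indices, then restore (N, n, n).
    return values.gather(1, cols.reshape(num_rows, n * n)).view(num_rows, n, n)

expand_indices = torch.tensor([0, 1, 0, 0, 1])
values = torch.tensor([[1, 2],
                       [3, 4],
                       [5, 6],
                       [7, 8],
                       [9, 10]])
result = expand_gather(expand_indices, values)
```

This reads each output element directly from `values`; whether it is faster than selecting between two expanded views has not been measured here.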