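I have the following 3D tensor (`tensor_3D`), which I want to scatter-reduce with a product over its first two dimensions using the index tensor (`index`): within each 3×3 slice, rows that share the same index value should be multiplied together element-wise.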
import torch
# Build a 3x3x3 tensor holding the integers 1..27 in row-major order,
# so tensor_3D[i, j, k] == 9*i + 3*j + k + 1.
tensor_3D = torch.arange(1, 28).view(3, 3, 3)
# tensor([[[ 1,  2,  3],
#          [ 4,  5,  6],
#          [ 7,  8,  9]],
#
#         [[10, 11, 12],
#          [13, 14, 15],
#          [16, 17, 18]],
#
#         [[19, 20, 21],
#          [22, 23, 24],
#          [25, 26, 27]]])

# Group label for each row j of every 3x3 slice: rows 0 and 1 belong to
# group 0, row 2 to group 1. Stored as a 3x1 column (2-D, not 1-D).
index = torch.tensor([[0], [0], [1]])
# tensor([[0],
#         [0],
#         [1]])
This is my desired tensor:
# Assemble the target column by column. Each d_n is the (group, k) result
# for slice tensor_3D[n-1]: its first row is the element-wise product of
# rows 0 and 1 (both labelled 0 in `index`), its second row is row 2 as-is.
d1 = torch.tensor([[4, 10, 18], [7, 8, 9]])
d2 = torch.tensor([[130, 154, 180], [16, 17, 18]])
d3 = torch.tensor([[418, 460, 504], [25, 26, 27]])
# Stacking on the last axis puts the slice number last: shape (2, 3, 3),
# indexed as desired[group, k, slice].
desired = torch.stack((d1, d2, d3), dim=-1)
# tensor([[[  4, 130, 418],
#          [ 10, 154, 460],
#          [ 18, 180, 504]],
#
#         [[  7,  16,  25],
#          [  8,  17,  26],
#          [  9,  18,  27]]])
So far I am doing the following, but the result does not have the dimensions I expect (it comes out as 3×2×3 rather than 2×3×3):
import torch
# Recreate the inputs: the integers 1..27 as a 3x3x3 tensor, and a 3x1
# column of group labels (rows 0 and 1 -> group 0, row 2 -> group 1).
tensor_3D = torch.arange(1, 28).reshape(3, 3, 3)
index = torch.tensor([0, 0, 1]).reshape(3, 1)

# scatter_reduce needs an index with the same number of dimensions as the
# source. Lay the labels along dim 1 and broadcast them over dims 0 and 2,
# so index_expand[i, j, k] == index[j]. Unlike index.expand(3, 3, 3), this
# does not hard-code the sizes (it assumes one label per row of dim 1).
index_expand = index.reshape(1, -1, 1).expand_as(tensor_3D)
print(index_expand)

# One output slot per group along dim 1 (assumes labels are non-negative
# and index is non-empty). Start from ones: 1 is the identity for a product,
# and include_self defaults to True, so the initial values are folded in.
num_groups = int(index.max()) + 1
new_tensor = torch.ones(
    tensor_3D.size(0), num_groups, tensor_3D.size(2), dtype=tensor_3D.dtype
)
print(new_tensor)

# res[i, g, k] = product of tensor_3D[i, j, k] over all j with index[j] == g;
# shape (3, 2, 3) -- the "wrong" dimensions from the question.
res = new_tensor.scatter_reduce(1, index_expand, tensor_3D, reduce="prod")

# The desired layout is (group, k, slice): result[g, k, i] == res[i, g, k].
# Permuting the axes of res was the missing step; result has shape (2, 3, 3).
result = res.permute(1, 2, 0)