PyTorch: Scattering a 3D tensor

I have the following 3D tensor (tensor_3D), which I want to scatter (reducing with a product) along its second dimension, using the index tensor (index) to group the rows.

import torch
# create a 3D tensor of size 3x3x3 with numbers from 1 to 27 
tensor_3D = torch.arange(1,28).reshape(3,3,3) 
#tensor([[[ 1,  2,  3],
#         [ 4,  5,  6],
#         [ 7,  8,  9]],
#
#        [[10, 11, 12],
#         [13, 14, 15],
#         [16, 17, 18]],
#
#        [[19, 20, 21],
#         [22, 23, 24],
#         [25, 26, 27]]])

# create an index tensor of shape 3x1 (one group id per row) to be used when scattering
index = torch.tensor([0,0,1]).reshape(3,1) 
#tensor([[0],
#        [0],
#        [1]])

This is my desired tensor:

# build the desired output by hand (shape 2x3x3)
d1 = torch.tensor([4, 10, 18, 7, 8, 9]).reshape(2, 3)
d2 = torch.tensor([130, 154, 180, 16, 17, 18]).reshape(2, 3)
d3 = torch.tensor([418, 460, 504, 25, 26, 27]).reshape(2, 3)
desired = torch.stack([d1,d2,d3], dim=2)
#tensor([[[  4, 130, 418],
#         [ 10, 154, 460],
#         [ 18, 180, 504]],
#
#        [[  7,  16,  25],
#         [  8,  17,  26],
#         [  9,  18,  27]]])
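
To make the relationship between tensor_3D, index and the desired tensor explicit, here is a loop-based sketch. It assumes the rule implied by the numbers above: entry [g, j, i] of the result is the product of tensor_3D[i, k, j] over all rows k whose index is g. The names num_groups and reference are illustrative only.

```python
import torch

tensor_3D = torch.arange(1, 28).reshape(3, 3, 3)
index = torch.tensor([0, 0, 1]).reshape(3, 1)

# Illustrative reference computation with plain loops:
# reference[g, j, i] = product of tensor_3D[i, k, j] over all rows k with index[k] == g
num_groups = int(index.max()) + 1
reference = torch.ones(num_groups, 3, 3, dtype=torch.int64)
for i in range(3):              # which 3x3 matrix of tensor_3D
    for k in range(3):          # which row of that matrix
        g = int(index[k, 0])    # group that row k belongs to
        for j in range(3):      # which column
            reference[g, j, i] *= tensor_3D[i, k, j]

print(reference.shape)  # torch.Size([2, 3, 3])
```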

As of now, I am doing the following, but I am not getting the correct dimensions:

import torch 
# create a 3D tensor of size 3x3x3 with numbers from 1 to 27 using torch.arange
tensor_3D = torch.arange(1,28).reshape(3,3,3) 
# create the index tensor (shape 3x1): one group id per row
index = torch.tensor([0,0,1]).reshape(3,1) 
#Expand index to 3x3x3
index_expand = index.expand(3,3,3)
print(index_expand)
# use index to scatter_reduce product the values of tensor_3D to a new tensor of size 3x2x3
new_tensor = torch.ones(3,2,3, dtype=torch.int64)
print(new_tensor)
res = new_tensor.scatter_reduce(
    1,
    index_expand, 
    tensor_3D, 
    reduce="prod")
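
A small sketch of what the attempted call computes, assuming a PyTorch version with the scatter_reduce(dim, index, src, reduce) signature (1.12 or newer, as far as I know). expand() aligns trailing dimensions, so the 3x1 index is broadcast over dims 0 and 2 and index_expand[i, k, j] == index[k, 0]. Scattering along dim 1 then groups the rows of each 3x3 matrix, so the result is laid out as [i, g, j] with shape 3x2x3.

```python
import torch

tensor_3D = torch.arange(1, 28).reshape(3, 3, 3)
index = torch.tensor([0, 0, 1]).reshape(3, 1)

# (3, 1) is treated as (1, 3, 1) and broadcast to (3, 3, 3) without copying data,
# so every 3x3 slice of index_expand is the same
index_expand = index.expand(3, 3, 3)
print(index_expand[0].tolist())  # [[0, 0, 0], [0, 0, 0], [1, 1, 1]]

new_tensor = torch.ones(3, 2, 3, dtype=torch.int64)
# dim=1 means: res[i, index_expand[i, k, j], j] *= tensor_3D[i, k, j]
res = new_tensor.scatter_reduce(1, index_expand, tensor_3D, reduce="prod")
print(res.shape)        # torch.Size([3, 2, 3])
print(res[0].tolist())  # [[4, 10, 18], [7, 8, 9]]: rows 0 and 1 multiplied, row 2 kept
```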

res.permute(1, 2, 0) should work and would create your desired output.
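
A quick check of the suggested permute, under the same version assumption. res is indexed as [i, g, j] and the desired tensor as [g, j, i], so permute(1, 2, 0) moves old dim 1 to the front, old dim 2 to the middle and old dim 0 to the end.

```python
import torch

tensor_3D = torch.arange(1, 28).reshape(3, 3, 3)
index_expand = torch.tensor([0, 0, 1]).reshape(3, 1).expand(3, 3, 3)
res = torch.ones(3, 2, 3, dtype=torch.int64).scatter_reduce(
    1, index_expand, tensor_3D, reduce="prod")

# res[i, g, j] -> out[g, j, i]
out = res.permute(1, 2, 0)

desired = torch.stack([
    torch.tensor([4, 10, 18, 7, 8, 9]).reshape(2, 3),
    torch.tensor([130, 154, 180, 16, 17, 18]).reshape(2, 3),
    torch.tensor([418, 460, 504, 25, 26, 27]).reshape(2, 3),
], dim=2)
print(torch.equal(out, desired))  # True
```

Note that permute returns a non-contiguous view; calling .contiguous() on it may be needed if a later operation (such as .view()) requires contiguous memory.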

Thank you, @ptrblck, for your answer. This does produce the desired output. However, I would like to understand how scatter_reduce() works for 3D tensors. Could you help me find a solution that does not require this ex-post modification (.permute()) to obtain the desired output?

To be concrete, I would like to know which index I need to provide if new_tensor were 2x3x3 from the beginning (unlike my attempted solution, where it starts as 3x2x3 and is permuted to 2x3x3 afterwards).
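
One way to get a 2x3x3 result directly, sketched under the same version assumption: scatter along dim 0 into a 2x3x3 tensor. As I understand the semantics, scatter_reduce only replaces the coordinate along dim and keeps every other coordinate of the source element, so the non-reduced axes of the output come out in the same order as in src. Because tensor_3D is laid out as [i, k, j] while the desired tensor is [g, j, i], this sketch permutes the source instead of the result; with scatter_reduce alone, some permute appears to be unavoidable. The names src and idx are illustrative.

```python
import torch

tensor_3D = torch.arange(1, 28).reshape(3, 3, 3)  # indexed as tensor_3D[i, k, j]
index = torch.tensor([0, 0, 1]).reshape(3, 1)

# Move the reduced axis k to dim 0 and order the other axes as the output wants:
# src[k, j, i] = tensor_3D[i, k, j]
src = tensor_3D.permute(1, 2, 0)

# The index needs as many dims as src: idx[k, j, i] = index[k, 0]
idx = index.reshape(3, 1, 1).expand_as(src)

# Start from ones, the neutral element for "prod"
new_tensor = torch.ones(2, 3, 3, dtype=torch.int64)
# dim=0 means: res[idx[k, j, i], j, i] *= src[k, j, i]
res = new_tensor.scatter_reduce(0, idx, src, reduce="prod")
print(res.shape)  # torch.Size([2, 3, 3])
```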

Additionally, I couldn't find examples for scatter_reduce() in the documentation that use tensors with more than two dimensions. Am I simply not looking in the right places, or are such examples not provided?
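
Regarding 3D examples: the PyTorch docs for Tensor.scatter_reduce_ (like those for scatter_ and scatter_add_) state the update rule for a 3-D tensor in index form, e.g. self[index[i][j][k]][j][k] += src[i][j][k] for dim == 0 with reduce="sum". The sketch below spells out the "prod" analogue as explicit loops and compares it with scatter_reduce on random data. scatter_reduce_prod_3d is a hypothetical helper written for illustration, and it assumes include_self=True (the default).

```python
import torch

def scatter_reduce_prod_3d(base, dim, index, src):
    """Hypothetical loop version of base.scatter_reduce(dim, index, src, reduce="prod")
    for 3-D tensors, with include_self=True."""
    out = base.clone()
    n0, n1, n2 = index.shape
    for i in range(n0):
        for j in range(n1):
            for k in range(n2):
                t = int(index[i, j, k])
                if dim == 0:
                    out[t, j, k] *= src[i, j, k]
                elif dim == 1:
                    out[i, t, k] *= src[i, j, k]
                else:
                    out[i, j, t] *= src[i, j, k]
    return out

torch.manual_seed(0)
src = torch.randint(1, 5, (3, 3, 3))
matches = []
for dim in range(3):
    size = [3, 3, 3]
    size[dim] = 2                      # two groups along the scattered dim
    index = torch.randint(0, 2, (3, 3, 3))
    base = torch.ones(size, dtype=src.dtype)
    expected = scatter_reduce_prod_3d(base, dim, index, src)
    actual = base.scatter_reduce(dim, index, src, reduce="prod")
    matches.append(torch.equal(expected, actual))
print(matches)  # [True, True, True]
```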

Thank you a lot for your time.