How to slice a tensor based on another tensor

For example, I have a target tensor of shape [1, 3, 3, 5] like this (image), and a length tensor [5, 3, 2].
If I directly flatten the tensor on dimension 2, the shape will be [1, 3, 15].
How do I flatten the tensor based on the lengths along dimension 3 so that the result is [1, 3, 10], where 10 = 5 + 3 + 2?
The desired output will be:
tensor([[[0.4076, 0.2769, 0.2220, 0.4076, 0.3851, 0.3649, 0.3190, 0.4107,
          0.3881, 0.3823],
         [0.3337, 0.3804, 0.3061, 0.3337, 0.3179, 0.3422, 0.3616, 0.3879,
          0.3302, 0.3508],
         [0.2588, 0.3427, 0.4719, 0.2588, 0.2970, 0.2930, 0.3194, 0.2014,
          0.2817, 0.2669]]],
       device='cuda:0', grad_fn=<...>)

Thanks!

I’m a bit confused about the mentioned shapes.

This shouldn’t be possible via reshaping the tensor, as the former has 5*3*2 = 30 elements while the latter has 1*3*15 = 45.

"Flattening" dimensions yields a dimension whose size is the product of their sizes.
If you want to create a tensor of shape [1, 3, 10] by flattening dim0 with dim2, you could use:

x = torch.randn(5, 3, 2)
y = x.permute(1, 0, 2).contiguous().view(1, 3, -1)
print(y.shape)
> torch.Size([1, 3, 10])
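As a side note (my addition, not part of the original reply): the contiguous() call is needed because view requires a contiguous memory layout, and a permuted tensor is not contiguous. reshape can be used instead to avoid the explicit call:

```python
import torch

x = torch.randn(5, 3, 2)

# view() requires contiguous memory; a permuted tensor is not contiguous,
# so calling view() on it directly raises a RuntimeError.
try:
    x.permute(1, 0, 2).view(1, 3, -1)
except RuntimeError as e:
    print("view failed:", e)

# reshape() handles the non-contiguous case, copying data when necessary.
y = x.permute(1, 0, 2).reshape(1, 3, -1)
print(y.shape)  # torch.Size([1, 3, 10])
```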

The tensor you show has shape (1, 3, 3, 5).

You can imagine it as a single 3x3x5 box.
Flattening is pretty much "scanning" it line by line, so you can flatten (3, 3, 5) to (3, 15).

torch.flatten(x, -2, -1) will flatten your 1x3x3x5 tensor to the 1x3x15 you were looking for (another option is to use view).
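To illustrate (a small sketch of the suggestion above, with random data):

```python
import torch

x = torch.rand(1, 3, 3, 5)

# Flatten the last two dimensions (3 and 5) into one of size 15.
flat = torch.flatten(x, -2, -1)
print(flat.shape)  # torch.Size([1, 3, 15])

# Equivalent via view, which works here because x is contiguous.
flat_v = x.view(1, 3, 15)
print(torch.equal(flat, flat_v))  # True
```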

Thanks for your reply! I am sorry for the confusion. What I meant is that I can easily flatten the original tensor (1, 3, 3, 5) to a (1, 3, 15) tensor, but I only want to keep some of the values in the original tensor, like this (image), so that the final output can be (1, 3, 10).

You can slice the specific parts and then concatenate them into one tensor:

>>> import torch
>>> a = torch.rand((1, 3, 3, 5))
>>> a
tensor([[[[0.1562, 0.1055, 0.5032, 0.0966, 0.2088],
          [0.8520, 0.3314, 0.6602, 0.4182, 0.7371],
          [0.5483, 0.2338, 0.9695, 0.6638, 0.9435]],

         [[0.2520, 0.9377, 0.7031, 0.9447, 0.9292],
          [0.3504, 0.1614, 0.0280, 0.7294, 0.4252],
          [0.6095, 0.9667, 0.6628, 0.9642, 0.2766]],

         [[0.8955, 0.5914, 0.3912, 0.8431, 0.9219],
          [0.4871, 0.5351, 0.9024, 0.3061, 0.1959],
          [0.2931, 0.2810, 0.7195, 0.8711, 0.6927]]]])
>>> c = a[:, :, 0, :]   # all 5 values at index 0 of dim 2
>>> c
tensor([[[0.1562, 0.1055, 0.5032, 0.0966, 0.2088],
         [0.2520, 0.9377, 0.7031, 0.9447, 0.9292],
         [0.8955, 0.5914, 0.3912, 0.8431, 0.9219]]])
>>> d = a[:, :, 1, :3]  # first 3 values at index 1 of dim 2
>>> d
tensor([[[0.8520, 0.3314, 0.6602],
         [0.3504, 0.1614, 0.0280],
         [0.4871, 0.5351, 0.9024]]])
>>> e = a[:, :, 2, :2]  # first 2 values at index 2 of dim 2
>>> e
tensor([[[0.5483, 0.2338],
         [0.6095, 0.9667],
         [0.2931, 0.2810]]])
>>> f = torch.cat((c, d, e), 2)  # concatenate along the last dimension
>>> f
tensor([[[0.1562, 0.1055, 0.5032, 0.0966, 0.2088, 0.8520, 0.3314, 0.6602,
          0.5483, 0.2338],
         [0.2520, 0.9377, 0.7031, 0.9447, 0.9292, 0.3504, 0.1614, 0.0280,
          0.6095, 0.9667],
         [0.8955, 0.5914, 0.3912, 0.8431, 0.9219, 0.4871, 0.5351, 0.9024,
          0.2931, 0.2810]]])

I think slicing and torch.cat should keep the gradients.
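Putting the thread together, here is a sketch that generalizes the slicing-and-concatenation approach to an arbitrary list of lengths (my own generalization, not code from the thread; flatten_by_lengths is a hypothetical helper name):

```python
import torch

def flatten_by_lengths(x, lengths):
    """For each index i along dim 2, keep x[:, :, i, :lengths[i]]
    and concatenate the pieces along the last dimension."""
    pieces = [x[:, :, i, :n] for i, n in enumerate(lengths)]
    return torch.cat(pieces, dim=-1)

x = torch.rand(1, 3, 3, 5, requires_grad=True)
out = flatten_by_lengths(x, [5, 3, 2])
print(out.shape)    # torch.Size([1, 3, 10])
print(out.grad_fn)  # non-None: slicing and cat keep the autograd graph
```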