Split tensor based on segments

Hi!

I am looking for the most efficient way to split a tensor into several tensors based on a 1-D index tensor.

Here is the full problem description:

I have a point cloud tensor with dimension (batchsize, 3, 2048) and an index tensor of size (batchsize, 2048) (LongTensor). This index array contains integer values from 0 to 3 (inclusive). I want to split the point cloud based on those indices such that I end up with 4 different tensors of shape (batchsize, 3, N_i) where the sum over all 4 different N_i equals 2048.
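Here is a minimal sketch of the batched case in PyTorch. It assumes `points` has shape (batchsize, 3, 2048) and `idx` has shape (batchsize, 2048), as described. One caveat: N_i generally differs from sample to sample, so a segment cannot be stored as a single rectangular (batchsize, 3, N_i) tensor. The helper below (`split_batch_by_index` is a hypothetical name) therefore returns, for each segment, a list of per-sample (3, N_i) tensors.

```python
import torch

def split_batch_by_index(points, idx, num_segments=4):
    """Hypothetical helper: split each cloud in a batch by per-point labels.

    points: (batch, 3, n_points) tensor
    idx:    (batch, n_points) LongTensor with values in [0, num_segments)
    Returns segments, where segments[k][b] is the (3, N_k) tensor holding the
    points of sample b that carry label k. N_k may differ between samples.
    """
    segments = [[] for _ in range(num_segments)]
    for b in range(points.shape[0]):
        for k in range(num_segments):
            mask = idx[b] == k                      # (n_points,) boolean mask
            segments[k].append(points[b][:, mask])  # keep only matching columns
    return segments

# Small example: 2 clouds of 10 points each, labels in {0, 1, 2, 3}.
torch.manual_seed(0)
points = torch.randn(2, 3, 10)
idx = torch.randint(0, 4, (2, 10))
segments = split_batch_by_index(points, idx)
for b in range(2):
    # Per-segment point counts for sample b; they always sum to 10.
    print([segments[k][b].shape[1] for k in range(4)])
```

If every sample happens to have the same number of points per segment, `torch.stack(segments[k])` turns the list into a (batchsize, 3, N_k) tensor. The Python loop favours clarity over speed: it performs batchsize × 4 boolean-mask selections.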

I’ve looked into torch.index_select, but I can’t figure out how to use this function for this.
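`torch.index_select` can do this, but it expects a 1-D LongTensor of positions rather than a label tensor. You therefore first convert the labels into positions with `torch.nonzero`. Below is a sketch for a single cloud of shape (3, n) with labels of shape (n,). `select_segment` is a hypothetical helper name. Because one index tensor is shared across the whole selected dimension, in the batched case you would call it once per sample.

```python
import torch

def select_segment(points, labels, k):
    """Hypothetical helper: columns of points (3, n) whose label equals k."""
    # 1-D LongTensor with the positions where the label equals k.
    positions = torch.nonzero(labels == k, as_tuple=True)[0]
    # Gather those columns along the point dimension (dim=1).
    return torch.index_select(points, 1, positions)

points = torch.arange(12.0).reshape(3, 4)  # 3 coordinates x 4 points
labels = torch.tensor([1, 0, 1, 2])
print(select_segment(points, labels, 1).tolist())  # [[0.0, 2.0], [4.0, 6.0], [8.0, 10.0]]
```

A boolean mask (`points[:, labels == k]`) gives the same result in one step. `index_select` is mainly convenient when you already have the positions.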

Your help would be much appreciated!

I’m not sure I understand the use case completely.
Could you post some examples using smaller tensor shapes, e.g. [1, 3, 10], and demonstrate how you would like to split this tensor using the indices?

I came here with a similar question. If I understood the OP correctly: given a tensor A of shape
(1, 3, 10) and a corresponding tensor B of shape (10,), where for example

A=
[[[3, 0, 0, 1, 8, 9, 5, 4, 1, 6],
[0, 0, 0, 8, 0, 7, 1, 0, 0, 8],
[1, 0, 2, 5, 3, 6, 3, 2, 4, 7]]]

and B=[0, 0, 1, 1, 0, 2, 2, 3, 3, 3]

the goal is to get 4 tensors, where the first contains all columns of A whose corresponding entry in B is 0.
That is,
[[[3, 0, 8],
[0, 0, 0],
[1, 0, 3]]]

the second one would be
[[[0, 1],
[0, 8],
[2, 5]]]

and so forth.
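Taking the example above as given, plain boolean-mask indexing on the last dimension reproduces these tensors directly in PyTorch. This is a sketch; `parts` is just an illustrative name.

```python
import torch

A = torch.tensor([[[3, 0, 0, 1, 8, 9, 5, 4, 1, 6],
                   [0, 0, 0, 8, 0, 7, 1, 0, 0, 8],
                   [1, 0, 2, 5, 3, 6, 3, 2, 4, 7]]])  # shape (1, 3, 10)
B = torch.tensor([0, 0, 1, 1, 0, 2, 2, 3, 3, 3])      # shape (10,)

# B == k is a (10,) boolean mask; indexing the last dim with it keeps the
# matching columns, giving a tensor of shape (1, 3, N_k).
parts = [A[:, :, B == k] for k in range(4)]

print(parts[0].tolist())  # [[[3, 0, 8], [0, 0, 0], [1, 0, 3]]]
print(parts[1].tolist())  # [[[0, 1], [0, 8], [2, 5]]]
```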

Is there an easy way to do this? And, to get to my actual question: is there also a way to generalize this to arbitrary index values, splitting in the same fashion so that each resulting tensor contains only the values that share the same index?
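One way to generalize is sketched below; it is not the only approach. Group the columns by label with a stable sort, then cut the sorted tensor into pieces with `torch.split`, using the per-label counts returned by `torch.unique(..., return_counts=True)`. This handles arbitrary label values, including non-contiguous and negative ones, and performs a single gather instead of one mask per label. `split_by_labels` is a hypothetical name. `labels` is assumed to be a 1-D integer tensor whose length matches the last dimension of `x`.

```python
import torch

def split_by_labels(x, labels):
    """Hypothetical helper: split x along its last dim, one tensor per label value.

    x:      tensor of shape (..., n)
    labels: 1-D integer tensor of shape (n,); any values, need not be 0..K-1
    Returns a dict {label_value: tensor of shape (..., count_of_that_label)}.
    """
    # Distinct label values (ascending) and how many columns carry each one.
    values, counts = torch.unique(labels, sorted=True, return_counts=True)
    # Stable sort: columns are grouped by label, original order kept inside a group.
    _, order = torch.sort(labels, stable=True)
    x_sorted = x[..., order]
    # Cut the grouped columns into consecutive chunks of the given sizes.
    chunks = torch.split(x_sorted, counts.tolist(), dim=-1)
    return {v: c for v, c in zip(values.tolist(), chunks)}

x = torch.arange(10).reshape(2, 5)
labels = torch.tensor([7, -1, 7, 3, -1])
for v, part in split_by_labels(x, labels).items():
    print(v, part.tolist())
# -1 [[1, 4], [6, 9]]
# 3 [[3], [8]]
# 7 [[0, 2], [5, 7]]
```

The stable sort preserves the original column order within each group, so each piece matches what `x[..., labels == v]` would return. With a per-sample label tensor of shape (batchsize, 2048), as in the original question, you would apply this to each sample separately, because the group sizes vary between samples.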