Split tensor based on segments



I am looking for the most efficient way to split a tensor into multiple different tensors based on a 1-D index tensor.

Here is the full problem description:

I have a point cloud tensor of shape (batchsize, 3, 2048) and an index tensor (a LongTensor) of shape (batchsize, 2048). This index tensor contains integer values from 0 to 3 (inclusive). I want to split the point cloud based on those indices so that I end up with 4 different tensors of shape (batchsize, 3, N_i), where the 4 values N_i sum to 2048.
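One caveat: N_i can differ from one batch element to the next, so the four results can only be regular (batchsize, 3, N_i) tensors if every sample happens to have the same number of points per segment. A minimal sketch, assuming the labels are in 0..3 and using boolean-mask indexing one sample at a time (the helper names `split_by_segment` and `stack_if_uniform` are hypothetical, not PyTorch API):

```python
import torch

def split_by_segment(points: torch.Tensor, idx: torch.Tensor, num_segments: int = 4):
    """Hypothetical helper: split a (B, 3, N) point cloud by per-point labels.

    Returns a list of length num_segments; entry k is a list of B tensors of
    shape (3, N_k), since the count per segment may differ between samples.
    Assumes idx is a (B, N) LongTensor with values in [0, num_segments).
    """
    out = []
    for k in range(num_segments):
        # (B, N) boolean mask: True where the point belongs to segment k
        mask = idx == k
        # Boolean indexing along the point dimension, sample by sample
        out.append([points[b][:, mask[b]] for b in range(points.shape[0])])
    return out

def stack_if_uniform(pieces):
    """Hypothetical helper: stack to (B, 3, N_k) only if all counts match."""
    # torch.stack needs identical shapes, so return None otherwise
    if len({p.shape[1] for p in pieces}) == 1:
        return torch.stack(pieces)
    return None
```

For example, `segs = split_by_segment(points, idx)` gives `segs[k][b]` as the points of sample `b` in segment `k`; `stack_if_uniform(segs[k])` then yields a single (batchsize, 3, N_k) tensor only when the counts agree across the batch.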

I’ve looked into torch.index_select, but I can’t figure out how I could use this function for this.
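`torch.index_select` expects a 1-D tensor of integer positions rather than a mask, so one way to use it is to convert the label comparison into positions with `torch.nonzero` first. A minimal sketch for a single sample (the helper name `select_segment` is hypothetical); because the index must be 1-D, this works per batch element:

```python
import torch

def select_segment(points: torch.Tensor, labels: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: gather the points of one sample with label k.

    points: (3, N) float tensor for a single sample
    labels: (N,) LongTensor of segment ids
    """
    # 1-D LongTensor of the positions where the label equals k
    positions = torch.nonzero(labels == k, as_tuple=True)[0]
    # Pick those columns along dim 1 (the point dimension)
    return torch.index_select(points, 1, positions)
```

This gives the same result as boolean indexing `points[:, labels == k]`; the explicit positions are mainly useful if you want to reuse them, e.g. to scatter results back later.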

Your help would be much appreciated!


I’m not sure I understand the use case completely.
Could you post some examples using smaller tensor shapes, e.g. [1, 3, 10], and demonstrate how you would like to split this tensor using the indices?
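For illustration, a small [1, 3, 10] case might look like the sketch below (the labels are made up, and `split_sorted` is a hypothetical helper). Instead of building one mask per segment, it sorts the labels once, reorders the points so each segment is contiguous, and cuts them apart with `torch.split` using per-segment counts from `torch.bincount`; this assumes labels in [0, num_segments).

```python
import torch

# Made-up example: one sample, 3 coordinates, 10 points
points = torch.arange(30, dtype=torch.float32).reshape(1, 3, 10)
labels = torch.tensor([[2, 0, 1, 3, 0, 2, 1, 0, 3, 2]])

def split_sorted(points_b: torch.Tensor, labels_b: torch.Tensor, num_segments: int = 4):
    """Hypothetical helper: split one sample (3, N) into (3, N_k) pieces."""
    # Stable sort keeps the original point order within each segment
    _, order = torch.sort(labels_b, stable=True)
    # Points per segment; minlength keeps empty segments as zero-width pieces
    counts = torch.bincount(labels_b, minlength=num_segments)
    # Make segments contiguous, then cut along the point dimension
    return torch.split(points_b[:, order], counts.tolist(), dim=1)

pieces = split_sorted(points[0], labels[0])
for k, p in enumerate(pieces):
    print(k, p.shape)
```

Sorting once may be preferable to repeated masking when there are many segments, but for only 4 labels the difference is likely small.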