I am looking for the most efficient way to split a tensor into several tensors based on an index tensor that assigns each element to a group.
Here is the full problem description:
I have a point cloud tensor of shape (batchsize, 3, 2048) and an index tensor (LongTensor) of shape (batchsize, 2048). The index tensor contains integer values from 0 to 3 (inclusive). I want to split the point cloud by these indices so that I end up with 4 tensors of shape (batchsize, 3, N_i), one per index value i, where N_0 + N_1 + N_2 + N_3 = 2048.
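A dense result of shape (batchsize, 3, N_i) only exists if every sample in the batch has the same N_i. Under that assumption, a minimal vectorized sketch is to stable-sort each row of the index tensor, reorder the points with `torch.gather`, and cut the result into consecutive chunks with `torch.split`. The function name `split_point_cloud` and the `num_groups` parameter are hypothetical, chosen for illustration; the function raises an error when the group sizes differ between samples.

```python
import torch
import torch.nn.functional as F

def split_point_cloud(points, idx, num_groups=4):
    """Hypothetical helper: split points (B, C, N) into num_groups tensors using idx (B, N).

    Assumes every sample has the same number of points in each group, so
    each result can be a dense tensor of shape (B, C, N_i).
    """
    B, C, N = points.shape
    # Stable sort keeps the original point order inside each group.
    _, order = torch.sort(idx, dim=1, stable=True)  # (B, N)
    # Reorder points so group 0 comes first, then group 1, and so on.
    sorted_points = torch.gather(points, 2, order.unsqueeze(1).expand(B, C, N))
    # Per-sample number of points in each group: (B, num_groups).
    counts = F.one_hot(idx, num_groups).sum(dim=1)
    if not (counts == counts[0]).all():
        raise ValueError("group sizes differ between samples; no dense (B, C, N_i) split exists")
    # Cut the sorted points into consecutive chunks of sizes N_0, N_1, ...
    return list(torch.split(sorted_points, counts[0].tolist(), dim=2))

B, N = 2, 2048
points = torch.randn(B, 3, N)
# Example indices with exactly 512 points per group in every sample.
idx = torch.stack([torch.randperm(N) % 4 for _ in range(B)])
parts = split_point_cloud(points, idx)
print([p.shape for p in parts])
# [torch.Size([2, 3, 512]), torch.Size([2, 3, 512]), torch.Size([2, 3, 512]), torch.Size([2, 3, 512])]
```

The sort-based approach avoids a Python loop over the batch, but it is only a sketch: if the group sizes vary per sample, the output has to be ragged (for example, a list of per-sample tensors) or padded instead.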
I’ve looked into torch.index_select, but I can’t figure out how to use it for this.
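`torch.index_select` expects a 1-D tensor of positions rather than group labels, so one possible way to use it is to convert each label into positions with `torch.nonzero` for every sample and group. The sketch below assumes a Python loop over the batch is acceptable; `split_with_index_select` is a hypothetical name. Unlike a dense split, it also works when N_i differs between samples, because it returns per-sample tensors.

```python
import torch

def split_with_index_select(points, idx, num_groups=4):
    """Hypothetical helper: groups[k][b] holds the points of sample b with index k.

    Each entry has shape (C, N_k^b), so group sizes may differ between samples.
    """
    groups = [[] for _ in range(num_groups)]
    for b in range(points.shape[0]):
        for k in range(num_groups):
            # 1-D LongTensor with the positions whose index equals k.
            positions = torch.nonzero(idx[b] == k, as_tuple=True)[0]
            # Select along the point dimension (dim 1 of the (C, N) slice).
            groups[k].append(points[b].index_select(1, positions))
    return groups

points = torch.randn(2, 3, 2048)
idx = torch.randint(0, 4, (2, 2048))
groups = split_with_index_select(points, idx)
```

If all samples happen to have the same size for group k, `torch.stack(groups[k])` turns the list into a single tensor of shape (batchsize, 3, N_k).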
Your help would be much appreciated!