Best alternative to torch.tensor_split without upgrading to 1.8

Hi all,

Looking for the exact functionality provided by torch.tensor_split in v1.8 without having to upgrade to 1.8 :wink:

import torch

predicted_ids = torch.tensor([100, 100, 31090, 100, 100, 100, 31090])
indices = torch.tensor([2, 6])  # split points
# desired result: [tensor([100, 100]), tensor([31090,   100,   100,   100]), tensor([31090])]

I would like the result shown in the comment above, but computed with PyTorch operations so I don't break the computational graph.

Is there a succinct way to do this?

Isn't it the same as torch.split(x, int_list)? You can use int_list = indices.tolist() if you have a reason to pass in a tensor of integers.

@googlebot Thanks. From what I can tell, torch.split accepts chunk sizes rather than indices. For example, to split x at indices 5, 10, and 15, I think you would do torch.split(x, [5, 5, 5, len(x) - 15]), since the sizes have to add up to len(x). So I ended up doing:

import numpy as np
import torch
chunk_sizes = [len(chunk) for chunk in np.array_split(predicted_ids, indices)]  # one length per chunk
desired_output = torch.split(predicted_ids, chunk_sizes)
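For illustration (not part of the original thread): torch.split is an ordinary autograd-aware op, so splitting by sizes keeps the graph intact. The sketch below uses a float stand-in for predicted_ids, since integer tensors cannot require gradients, and checks that a gradient taken through one chunk flows back into the original tensor.

```python
import torch

# Float stand-in for predicted_ids (integer tensors cannot require grad).
x = torch.tensor([100., 100., 31090., 100., 100., 100., 31090.], requires_grad=True)

# torch.split takes chunk *sizes*: [2, 4, 1] corresponds to split points 2 and 6.
chunks = torch.split(x, [2, 4, 1])

# Backpropagate through the middle chunk only.
chunks[1].sum().backward()
print(x.grad)  # ones at positions 2..5, zeros elsewhere
```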

If there is a simpler way, I'd definitely be interested to hear it!

Ah, right. Then something like the following would calculate the sizes:

[2, 1, 97]

i.e. append the total size, prepend a zero if you have split points instead of indexes, and take the deltas (differences between consecutive entries).
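A minimal sketch of that recipe in pure PyTorch, assuming `indices` is a sorted 1-D integer tensor of in-range split points. The helper names `split_sizes` and `split_at_indices` are hypothetical. The deltas are taken by slicing rather than with torch.diff, which may not be available before 1.8.

```python
import torch

def split_sizes(indices, total):
    """Chunk sizes for splitting a length-`total` dim at sorted split points `indices`."""
    # Boundaries: a leading 0, the split points, then the total length.
    bounds = torch.cat([
        torch.zeros(1, dtype=torch.long),
        indices.to(torch.long),
        torch.tensor([total], dtype=torch.long),
    ])
    # Consecutive differences ("deltas"); slicing avoids relying on torch.diff.
    return (bounds[1:] - bounds[:-1]).tolist()

def split_at_indices(x, indices):
    """Split `x` along dim 0 at the split points, similar in spirit to torch.tensor_split."""
    return torch.split(x, split_sizes(indices, x.size(0)))

print(split_sizes(torch.tensor([2, 3]), 100))  # [2, 1, 97]

predicted_ids = torch.tensor([100, 100, 31090, 100, 100, 100, 31090])
chunks = split_at_indices(predicted_ids, torch.tensor([2, 6]))
print([c.tolist() for c in chunks])  # [[100, 100], [31090, 100, 100, 100], [31090]]
```

Unlike torch.tensor_split, this sketch does not handle unsorted or out-of-range indices: they produce negative sizes, and torch.split will raise an error.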

Of course, thanks! I also saw this solution (essentially the same idea), just using indexing.
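The indexing solution itself is not reproduced in the thread; the following is only a guess at the same idea, assuming sorted, in-range split points. The helper `split_by_indexing` is hypothetical. Each chunk is a plain slice of `x`, so autograd tracks it like any other view, and the result is a list, matching the format in the original question.

```python
import torch

def split_by_indexing(x, indices):
    """Slice `x` between consecutive split points (hypothetical helper)."""
    bounds = [0] + indices.tolist() + [x.size(0)]
    # Pair each start boundary with the next one and take the slice between them.
    return [x[start:end] for start, end in zip(bounds[:-1], bounds[1:])]

predicted_ids = torch.tensor([100, 100, 31090, 100, 100, 100, 31090])
parts = split_by_indexing(predicted_ids, torch.tensor([2, 6]))
print([p.tolist() for p in parts])  # [[100, 100], [31090, 100, 100, 100], [31090]]
```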