Looking for the exact functionality provided by torch.tensor_split in v1.8 without having to upgrade to 1.8
predicted_ids = torch.tensor([100, 100, 31090, 100, 100, 100, 31090])
indices = torch.tensor([2, 6])
torch.tensor_split(predicted_ids, indices)
# returns [tensor([100, 100]), tensor([31090, 100, 100, 100]), tensor([31090])]
I would like the result on the last line above, but via PyTorch operations available in my version so I can avoid breaking the computational graph.
Is there a succinct way to do this?
Isn't it the same as torch.split(x, int_list)? You can do int_list = indices.tolist() if you have a reason to feed in a tensor of integers.
@googlebot Thanks. From what I could tell, torch.split accepts chunk sizes rather than split indices. For example, to split
x of length 20 at indices 5, 10, and 15, I think you would do
torch.split(x, [5, 5, 5, 5]). So I just ended up doing
import numpy as np  # np.array_split is only used to compute the sizes
chunk_sizes = [len(c) for c in np.array_split(predicted_ids, indices)]
desired_output = torch.split(predicted_ids, chunk_sizes)
If there is a simpler way I would definitely be interested in hearing!
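For reference, a runnable version of that workaround on the thread's own data (the `.numpy()` calls are my addition, to make the hand-off to NumPy explicit; only the sizes pass through NumPy, so the actual split stays in the graph):

```python
import numpy as np
import torch

predicted_ids = torch.tensor([100, 100, 31090, 100, 100, 100, 31090])
indices = torch.tensor([2, 6])

# np.array_split accepts split indices, so use it only to derive chunk
# sizes; torch.split then performs the actual (graph-preserving) split.
chunk_sizes = [len(c) for c in np.array_split(predicted_ids.numpy(), indices.numpy())]
desired_output = torch.split(predicted_ids, chunk_sizes)
# chunk_sizes == [2, 4, 1]
```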
Ah, right. Then something like the following would calculate the sizes
>[2, 1, 97]
i.e. append the total size, prepend a zero if you have split points rather than sizes, and take the deltas.
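A minimal sketch of that recipe (the helper name split_sizes is mine, and the inputs are chosen to reproduce the quoted [2, 1, 97]; it uses only plain Python plus torch.split, so it works pre-1.8 and keeps the graph intact):

```python
import torch

def split_sizes(split_points, total):
    # Prepend 0 and append the total size, then take deltas between
    # consecutive boundaries to get chunk sizes for torch.split.
    bounds = [0] + list(split_points) + [total]
    return [b - a for a, b in zip(bounds, bounds[1:])]

x = torch.arange(100.0)
sizes = split_sizes([2, 3], len(x))  # -> [2, 1, 97]
chunks = torch.split(x, sizes)
```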
Of course, thanks! Also saw this solution (essentially the same) just using indexing.