Hi all,
I'm looking for the exact functionality provided by torch.tensor_split in v1.8, without having to upgrade to 1.8.
predicted_ids = torch.tensor([100, 100, 31090, 100, 100, 100, 31090])
indices = torch.tensor([2,6])
np.array_split(predicted_ids,indices)
# returns [tensor([100, 100]), tensor([31090, 100, 100, 100]), tensor([31090])]
I would like the output shown in the last line above, but produced with PyTorch operations so I can avoid breaking the computational graph.
Is there a succinct way to do this?
Isn’t it the same as torch.split(x, int_list)? You can do int_list=indices.tolist() if you have a reason to feed in a tensor of integers.
@googlebot Thanks. From what I could tell, torch.split accepts chunk sizes rather than indices. For example, to split x at indices 5, 10, and 15, I think you would do torch.split(x,[5,5,5]). So I just ended up doing
chunk_sizes = [len(x) for x in np.array_split(predicted_ids,indices)]
desired_output = torch.split(predicted_ids,chunk_sizes)
If there is a simpler way I would definitely be interested in hearing!
Ah, right. Then something like the following would calculate the sizes:
np.diff(np.asarray([0]+[2,3]+[100])).tolist()
>[2, 1, 97]
I.e., if you have split points rather than chunk sizes, prepend zero, append the total size, and take the deltas.
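For anyone landing here later, the "prepend zero, append total size, take deltas" recipe can be wrapped up roughly as below. This is only a sketch: split_at_indices is a made-up helper name, not a PyTorch function, and it assumes a split along dim 0 with sorted, in-range indices.

```python
import torch

def split_at_indices(t, indices):
    # Hypothetical helper (not part of PyTorch): split `t` along dim 0
    # at the given split points, in the style of np.array_split.
    if torch.is_tensor(indices):
        indices = indices.tolist()
    # Prepend zero and append the total size ...
    bounds = [0] + list(indices) + [t.shape[0]]
    # ... then take deltas to turn split points into chunk sizes.
    sizes = [b - a for a, b in zip(bounds[:-1], bounds[1:])]
    return torch.split(t, sizes)

predicted_ids = torch.tensor([100, 100, 31090, 100, 100, 100, 31090])
chunks = split_at_indices(predicted_ids, torch.tensor([2, 6]))
# chunks is a tuple of three tensors with sizes 2, 4 and 1
```

Since this only calls torch.split, no NumPy round trip is involved.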
Of course, thanks! I also saw this solution (essentially the same), which just uses indexing.
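The solution referred to above is not reproduced in this thread, so the following is only a guess at what an indexing-based version might look like. split_by_slicing is a hypothetical name, and the same assumptions apply (dim 0, sorted in-range split points given as a Python list).

```python
import torch

def split_by_slicing(t, indices):
    # Hypothetical helper: same boundaries as before, but each chunk is
    # taken with plain slicing instead of torch.split.
    bounds = [0] + list(indices) + [t.shape[0]]
    return [t[a:b] for a, b in zip(bounds[:-1], bounds[1:])]

x = torch.ones(7, requires_grad=True)
parts = split_by_slicing(x, [2, 6])
# Each slice is a view of x, so gradients still flow back to x.
```

Either way, the chunks come from ordinary tensor operations, so autograd keeps tracking them.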