Thanks a lot for your very precise answer!
I modified it a bit so that it no longer uses NumPy, and it works great:
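import torch

def get_split_wo_numpy(x):
    # The integer part of the last column (the timestamps) labels each group.
    d = x[:, -1].to(torch.int32)
    # Number of rows per distinct integer; x is assumed to be sorted by timestamp.
    _, counts = torch.unique(d, return_counts=True)
    # Split x along dim 0 into chunks of those sizes.
    return torch.split(x, counts.tolist())

# create tensor
N = 50
c0 = torch.randn(N)
c1 = torch.arange(N) + 0.1 * torch.randn(N)
x = torch.stack((c0, c1), dim=1)
print(*get_split_wo_numpy(x), sep='\n')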
> tensor([[-0.3783, 0.0240]])
tensor([[0.5671, 1.0883],
[0.1693, 1.9510]])
tensor([[0.4059, 2.9808]])
tensor([[-0.8332, 3.9922]])
tensor([[-1.7366, 4.8935]])
tensor([[0.1554, 5.9210]])
tensor([[-0.5020, 7.0563]])
tensor([[0.8593, 8.1128],
[0.6790, 8.8930]])
tensor([[-1.3425, 9.9942]])
tensor([[ 0.4977, 11.1460]])
tensor([[-0.6046, 12.0957]])
tensor([[ 0.0482, 13.0584]])
tensor([[-3.5339, 14.0837],
[ 0.3017, 14.8644]])
tensor([[-1.0470, 15.8537]])
tensor([[ 0.5970, 17.0868],
[ 1.1420, 18.0000]])
tensor([[ 0.2047, 18.9484]])
tensor([[ 0.6992, 19.9839]])
tensor([[-0.7491, 21.0039]])
tensor([[-1.5463, 22.0174]])
tensor([[-1.1534, 23.0604],
[-1.6386, 23.9969]])
tensor([[-0.6788, 24.9266]])
tensor([[-0.6330, 25.9962]])
tensor([[ 0.8385, 27.0327],
[-0.5460, 27.9246]])
tensor([[ 1.0593, 29.0393]])
tensor([[ 0.8780, 30.1633]])
tensor([[ 0.5493, 31.0493],
[ 0.3947, 31.9193]])
tensor([[ 0.1004, 32.8814]])
tensor([[-1.3328, 34.1098]])
tensor([[ 0.5115, 35.0216],
[-0.1866, 35.9086]])
tensor([[-0.5092, 37.1763],
[ 0.6438, 37.8571]])
tensor([[ 0.3718, 38.8968]])
tensor([[-1.6322, 40.0297],
[-0.7793, 40.9798]])
tensor([[ 0.2476, 41.9675]])
tensor([[ 0.4328, 43.0768]])
tensor([[ 0.8585, 44.1284],
[ 0.1131, 44.8837]])
tensor([[-0.2042, 46.1367]])
tensor([[ 0.7368, 47.0504],
[-0.1325, 47.9358]])
tensor([[-1.1966, 48.9396]])
Moreover, assuming that the sequence is sorted, one can replace torch.unique with torch.unique_consecutive. These minor modifications make it a bit faster than your original function:
%timeit get_split(x)
%timeit get_split_wo_numpy(x)
%timeit get_split_wo_numpy_consecutive(x)
1000 loops, best of 3: 707 µs per loop
1000 loops, best of 3: 565 µs per loop
1000 loops, best of 3: 405 µs per loop
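For reference, the get_split_wo_numpy_consecutive variant timed above is not spelled out in this post. A minimal sketch of what it could look like, assuming the rows are already sorted by the last column (the small tensor at the end is made up for illustration):

```python
import torch

def get_split_wo_numpy_consecutive(x):
    # Integer part of the last column (assumed to hold sorted timestamps).
    d = x[:, -1].to(torch.int32)
    # Lengths of runs of consecutive equal values; no sorting is performed.
    _, counts = torch.unique_consecutive(d, return_counts=True)
    # torch.split accepts a list of chunk sizes along dim 0.
    return torch.split(x, counts.tolist())

# Hypothetical data: integer parts of the last column are 0, 0, 1, 3.
x = torch.tensor([[0.5, 0.1], [0.3, 0.9], [0.7, 1.2], [0.2, 3.4]])
chunks = get_split_wo_numpy_consecutive(x)
print([c.shape[0] for c in chunks])  # [2, 1, 1]
```

If the input is not sorted, this version would produce one chunk per run rather than one chunk per distinct integer, so the sortedness assumption matters here.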
The problem with my for loop was that it assumed that, for every i, there is an event whose timestamp lies between i and i+1, which is not true in general.
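To make the gap case concrete, here is a small sketch with made-up timestamps in which no event falls in [1, 2) or [2, 3). Counting events per integer bin shows the empty bins, while the run-length counts simply skip them:

```python
import torch

# Hypothetical sorted timestamps with a gap between 1 and 3.
t = torch.tensor([0.2, 0.7, 3.1])
bins = t.to(torch.int64)  # integer part of each timestamp: 0, 0, 3

# Events per integer bin, including the empty ones.
per_bin = torch.bincount(bins)
print(per_bin.tolist())  # [2, 0, 0, 1]

# Run lengths only cover bins that actually contain events.
_, counts = torch.unique_consecutive(bins, return_counts=True)
print(counts.tolist())  # [2, 1]
```

A loop that expects exactly one group for each i would have to handle the zero entries explicitly, whereas the counts-based split never sees them.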
Thanks again for your solution!
Best,
Alain