Split sequence into blocks according to its content

Thanks a lot for your very precise answer!

I modified it a bit so that it does not use NumPy, and it works great:

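import torch

def get_split_wo_numpy(x):
    # Integer part of the last column (the timestamp) labels each block
    d = x[:, -1].to(torch.int32)
    # Block sizes; torch.unique sorts its output, so x must already be sorted by timestamp
    _, counts = torch.unique(d, return_counts=True)
    return torch.split(x, counts.tolist())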

# create tensor
N = 50
c0 = torch.randn(N)
c1 = torch.arange(N) + 0.1 * torch.randn(N)
x = torch.stack((c0, c1), dim=1)

print(*get_split_wo_numpy(x), sep='\n')
> tensor([[-0.3783,  0.0240]])
tensor([[0.5671, 1.0883],
        [0.1693, 1.9510]])
tensor([[0.4059, 2.9808]])
tensor([[-0.8332,  3.9922]])
tensor([[-1.7366,  4.8935]])
tensor([[0.1554, 5.9210]])
tensor([[-0.5020,  7.0563]])
tensor([[0.8593, 8.1128],
        [0.6790, 8.8930]])
tensor([[-1.3425,  9.9942]])
tensor([[ 0.4977, 11.1460]])
tensor([[-0.6046, 12.0957]])
tensor([[ 0.0482, 13.0584]])
tensor([[-3.5339, 14.0837],
        [ 0.3017, 14.8644]])
tensor([[-1.0470, 15.8537]])
tensor([[ 0.5970, 17.0868],
        [ 1.1420, 18.0000]])
tensor([[ 0.2047, 18.9484]])
tensor([[ 0.6992, 19.9839]])
tensor([[-0.7491, 21.0039]])
tensor([[-1.5463, 22.0174]])
tensor([[-1.1534, 23.0604],
        [-1.6386, 23.9969]])
tensor([[-0.6788, 24.9266]])
tensor([[-0.6330, 25.9962]])
tensor([[ 0.8385, 27.0327],
        [-0.5460, 27.9246]])
tensor([[ 1.0593, 29.0393]])
tensor([[ 0.8780, 30.1633]])
tensor([[ 0.5493, 31.0493],
        [ 0.3947, 31.9193]])
tensor([[ 0.1004, 32.8814]])
tensor([[-1.3328, 34.1098]])
tensor([[ 0.5115, 35.0216],
        [-0.1866, 35.9086]])
tensor([[-0.5092, 37.1763],
        [ 0.6438, 37.8571]])
tensor([[ 0.3718, 38.8968]])
tensor([[-1.6322, 40.0297],
        [-0.7793, 40.9798]])
tensor([[ 0.2476, 41.9675]])
tensor([[ 0.4328, 43.0768]])
tensor([[ 0.8585, 44.1284],
        [ 0.1131, 44.8837]])
tensor([[-0.2042, 46.1367]])
tensor([[ 0.7368, 47.0504],
        [-0.1325, 47.9358]])
tensor([[-1.1966, 48.9396]])

Moreover, assuming that the sequence is sorted, one can replace torch.unique with torch.unique_consecutive. These minor modifications make it a bit faster than your original function:

%timeit get_split(x)
%timeit get_split_wo_numpy(x)
%timeit get_split_wo_numpy_consecutive(x)
1000 loops, best of 3: 707 µs per loop
1000 loops, best of 3: 565 µs per loop
1000 loops, best of 3: 405 µs per loop
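
The body of get_split_wo_numpy_consecutive is not shown above. A minimal sketch of what it could look like, assuming it is the same function with torch.unique swapped for torch.unique_consecutive (the small tensor x is made up for illustration):

```python
import torch

def get_split_wo_numpy_consecutive(x):
    # Integer part of the last column (the timestamp) labels each block
    d = x[:, -1].to(torch.int32)
    # Counts runs of equal adjacent values, so no sort is performed;
    # this is only correct if x is already sorted by timestamp
    _, counts = torch.unique_consecutive(d, return_counts=True)
    return torch.split(x, counts.tolist())

# Illustrative input: timestamps 0.2 | 1.1, 1.9 | 3.4 (nothing in [2, 3))
x = torch.tensor([[0.1, 0.2], [0.3, 1.1], [0.5, 1.9], [0.7, 3.4]])
blocks = get_split_wo_numpy_consecutive(x)
print([b.shape[0] for b in blocks])  # [1, 2, 1]
```

Since torch.unique_consecutive only compares neighbouring elements, it skips the sorting step that torch.unique performs, which is presumably where the speed-up comes from.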

The problem with my for loop was that it assumed that, for every i, there is an event whose timestamp lies between i and i+1, which is not true in general.
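
To illustrate the point about missing intervals, here is a small sketch with made-up timestamps (ts is hypothetical, not the data above). Counting events for every integer i shows the empty intervals, while the run-length counts used for splitting only cover intervals that actually contain events:

```python
import torch

# Made-up sorted timestamps: no event falls in [1, 2) or [2, 3)
ts = torch.tensor([0.2, 0.7, 3.1])
d = ts.long()  # integer parts: 0, 0, 3

# One count per integer i from 0 to max(d), including empty intervals
per_interval = torch.bincount(d)

# One count per run of equal values, so empty intervals never appear
_, counts = torch.unique_consecutive(d, return_counts=True)

print(per_interval.tolist())  # [2, 0, 0, 1]
print(counts.tolist())        # [2, 1]
```

The zeros in the first list are exactly the cases that a loop over every i has to handle, whereas torch.split only ever receives the non-empty block sizes.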

Thanks again for your solution!

Best,

Alain