Split sequence into blocks according to its content

I have a sequence of L events stored in a tensor of shape L x W, where the last feature of each event is a timestamp indicating when the event occurs (it is actually a musical sequence in which events are notes, and the timestamp indicates when each note is played).

What I’d like to do is split this sequence into blocks according to the events’ timestamps.

More concretely, let’s assume my sequence is:

tensor([
    [n1, 0],
    [n2, 0.25],
    [n3, 0.75],
    [n4, 1],
    [n5, 1.5],
    [n6, 2.1],
    ...
])

What I want is to split the sequence into blocks so that each block contains the events appearing within the same 1-second window, i.e., for the previous example, return:

[
    tensor([
        [n1, 0],
        [n2, 0.25],
        [n3, 0.75],
    ]),
    tensor([
        [n4, 1],
        [n5, 1.5]
    ]),
    tensor([
        [n6, 2.1]
    ]),
    ...
]

The only solution I see is to “naively” use Python lists and for loops, so I was wondering whether there exists a function that does this quickly. I was thinking about using torch.split, which is much faster and might do the job, but I can’t figure out how to use it here.
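For reference, torch.split also accepts a list of section sizes instead of a single chunk size; the sizes must sum to the length of the split dimension, and each entry is the number of rows in the corresponding block. A minimal sketch on the toy timestamps used in the benchmark below, with the sizes [3, 2, 1, 1] written by hand:

```python
import torch

# Toy sequence: a single feature per event (the timestamp), as in the benchmark below
seq = torch.tensor([0, 0.25, 0.75, 1, 1.5, 2.1, 3]).reshape(-1, 1)

# Hand-written block sizes: 3 events in [0, 1), 2 in [1, 2), 1 in [2, 3), 1 in [3, 4)
sizes = [3, 2, 1, 1]
assert sum(sizes) == seq.shape[0]  # torch.split needs the sizes to cover every row

blocks = torch.split(seq, sizes)  # tuple of views into seq, no data is copied
for block in blocks:
    print(block.squeeze(1).tolist())
```

So the remaining problem is only how to compute that list of sizes from the timestamps.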

Benchmark comparison between the for loop and torch.split (with hand-written split sizes):

$ python3 -m timeit -s 'import torch; seq=torch.tensor([0, 0.25, 0.75, 1, 1.5, 2.1, 3]).reshape(-1,1)' '
blocks = []
block = []
limit = 1
for event in seq:
    if event[-1] >= limit:
        blocks.append(torch.stack(block, dim=0))
        limit += 1
        block = [event]
    else:
        block.append(event)
blocks.append(torch.stack(block, dim=0))
'
1000 loops, best of 3: 418 usec per loop
$ python3 -m timeit -s 'import torch; seq=torch.tensor([0, 0.25, 0.75, 1, 1.5, 2.1, 3]).reshape(-1,1)' 'torch.split(seq, [3, 2, 1, 1])'
100000 loops, best of 3: 12.2 usec per loop

Is there an efficient way to solve this problem?

Best,

Alain

If you can use NumPy, this method should work:

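import numpy as np
import torch

def get_splits(x):
    # Integer bin edges 0, 1, ..., ceil(max timestamp)
    bins = np.arange(0, np.ceil(x[:,1].max())+1)
    # Bin index of each timestamp: bins[i-1] <= t < bins[i]
    d = torch.from_numpy(np.digitize(x.numpy()[:, 1], bins))
    # Number of events per non-empty bin (rows must be sorted by timestamp)
    _, counts = torch.unique(d, return_counts=True)
    return torch.split(x, counts.tolist())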

# Create tensor
c0 = torch.arange(50).view(50, 1).float()
c1 = torch.arange(50).view(50, 1) + torch.randn(50, 1) * 0.1
x = torch.cat((c0, c1), dim=1)

print(get_splits(x))
> (tensor([[ 0.0000, -0.0317]]),
 tensor([[1.0000, 1.0548]]),
 tensor([[2.0000, 2.1610]]),
 tensor([[3.0000, 3.0097]]),
 tensor([[4.0000, 4.0108]]),
 tensor([[5.0000, 5.0012]]),
 tensor([[6.0000, 6.0031],
         [7.0000, 6.9909]]),
 tensor([[8.0000, 8.2395],
         [9.0000, 8.8865]]),
 tensor([[10.0000,  9.9980]]),
 tensor([[11.0000, 11.0808],
         [12.0000, 11.9696]]),
 tensor([[13.0000, 12.9840]]),
 tensor([[14.0000, 14.0748]]),
 tensor([[15.0000, 15.0449]]),
 tensor([[16.0000, 16.0221]]),
 tensor([[17.0000, 17.2724],
         [18.0000, 17.9412]]),
 tensor([[19.0000, 19.0168]]),
 tensor([[20.0000, 20.0405]]),
 tensor([[21.0000, 21.0120],
         [22.0000, 21.8588]]),
 tensor([[23.0000, 22.9400]]),
 tensor([[24.0000, 24.0344],
         [25.0000, 24.8939]]),
 tensor([[26.0000, 26.0936],
         [27.0000, 26.9038]]),
 tensor([[28.0000, 27.8985]]),
 tensor([[29.0000, 28.9492]]),
 tensor([[30.0000, 29.8341]]),
 tensor([[31.0000, 30.9124]]),
 tensor([[32.0000, 31.9069]]),
 tensor([[33.0000, 33.0332],
         [34.0000, 33.8976]]),
 tensor([[35.0000, 34.8466]]),
 tensor([[36.0000, 35.9875]]),
 tensor([[37.0000, 36.9951]]),
 tensor([[38.0000, 37.9140]]),
 tensor([[39.0000, 39.1045]]),
 tensor([[40.0000, 40.0363]]),
 tensor([[41.0000, 41.0674],
         [42.0000, 41.7366]]),
 tensor([[43.0000, 42.9243]]),
 tensor([[44.0000, 44.1054],
         [45.0000, 44.9093]]),
 tensor([[46.0000, 45.9525]]),
 tensor([[47.0000, 46.9660]]),
 tensor([[48.0000, 47.9752]]),
 tensor([[49.0000, 48.9403]]))

%timeit get_splits(x) returns
211 µs ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each),
while your loop would take
1.35 ms ± 104 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each).
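If NumPy is not available, torch.bucketize can play the role of np.digitize: with right=True it returns the index i such that boundaries[i-1] <= v < boundaries[i], which matches np.digitize’s default convention. A minimal sketch, assuming the timestamps are in column 1 and the rows are sorted by timestamp (torch.unique returns counts in sorted-bin order, which only lines up with the rows when they are already sorted); get_splits_bucketize is a hypothetical name:

```python
import torch

def get_splits_bucketize(x):
    t = x[:, 1]
    # Bin edges 0, 1, ..., ceil(max), same as `bins` in the NumPy version
    bins = torch.arange(0, torch.ceil(t.max()).item() + 1, dtype=t.dtype)
    # right=True reproduces np.digitize(t, bins): bins[i-1] <= t < bins[i]
    d = torch.bucketize(t, bins, right=True)
    # Empty bins never occur in `d`, so gaps between timestamps are handled
    _, counts = torch.unique(d, return_counts=True)
    return torch.split(x, counts.tolist())
```

Because only non-empty bins get a count, a second with no events simply produces no block, which is exactly the case the loop below gets wrong.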

Note, however, that I don’t get valid results using your loop (for example, the events at 8.2395 and 8.8865 end up in separate blocks):

def reference(seq):
    blocks = []
    block = []
    limit = 1
    for event in seq:
        if event[-1] >= limit:
            blocks.append(torch.stack(block, dim=0))
            limit += 1
            block = [event]
        else:
            block.append(event)
    blocks.append(torch.stack(block, dim=0))
    return blocks

print(reference(x))
> [tensor([[ 0.0000, -0.0317]]),
 tensor([[1.0000, 1.0548]]),
 tensor([[2.0000, 2.1610]]),
 tensor([[3.0000, 3.0097]]),
 tensor([[4.0000, 4.0108]]),
 tensor([[5.0000, 5.0012]]),
 tensor([[6.0000, 6.0031],
         [7.0000, 6.9909]]),
 tensor([[8.0000, 8.2395]]),
 tensor([[9.0000, 8.8865]]),
 tensor([[10.0000,  9.9980]]),
 tensor([[11.0000, 11.0808]]),
 tensor([[12.0000, 11.9696]]),
 tensor([[13.0000, 12.9840]]),
 tensor([[14.0000, 14.0748]]),
 tensor([[15.0000, 15.0449]]),
 tensor([[16.0000, 16.0221]]),
 tensor([[17.0000, 17.2724]]),
 tensor([[18.0000, 17.9412]]),
 tensor([[19.0000, 19.0168]]),
 tensor([[20.0000, 20.0405]]),
 tensor([[21.0000, 21.0120]]),
 tensor([[22.0000, 21.8588]]),
 tensor([[23.0000, 22.9400]]),
 tensor([[24.0000, 24.0344]]),
 tensor([[25.0000, 24.8939]]),
 tensor([[26.0000, 26.0936]]),
 tensor([[27.0000, 26.9038]]),
 tensor([[28.0000, 27.8985]]),
 tensor([[29.0000, 28.9492]]),
 tensor([[30.0000, 29.8341]]),
 tensor([[31.0000, 30.9124]]),
 tensor([[32.0000, 31.9069]]),
 tensor([[33.0000, 33.0332]]),
 tensor([[34.0000, 33.8976]]),
 tensor([[35.0000, 34.8466]]),
 tensor([[36.0000, 35.9875]]),
 tensor([[37.0000, 36.9951]]),
 tensor([[38.0000, 37.9140]]),
 tensor([[39.0000, 39.1045]]),
 tensor([[40.0000, 40.0363]]),
 tensor([[41.0000, 41.0674]]),
 tensor([[42.0000, 41.7366]]),
 tensor([[43.0000, 42.9243]]),
 tensor([[44.0000, 44.1054]]),
 tensor([[45.0000, 44.9093]]),
 tensor([[46.0000, 45.9525]]),
 tensor([[47.0000, 46.9660]]),
 tensor([[48.0000, 47.9752]]),
 tensor([[49.0000, 48.9403]])]

Thanks a lot for your very precise answer!

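I modified it a bit so that it doesn’t use NumPy, and it works great:

def get_split_wo_numpy(x):
    # Integer part of the timestamp = index of its 1-second window
    d = x[:,-1].to(torch.int32)
    # Number of events per window (rows are assumed to be sorted by timestamp)
    _, counts = torch.unique(d, return_counts=True)
    return torch.split(x, counts.tolist())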

# create tensor
N = 50
c0 = torch.randn(N)
c1 = torch.arange(N) + 0.1 * torch.randn(N)
x = torch.stack((c0, c1), dim=1)

print(*get_split_wo_numpy(x), sep='\n')
> tensor([[-0.3783,  0.0240]])
tensor([[0.5671, 1.0883],
        [0.1693, 1.9510]])
tensor([[0.4059, 2.9808]])
tensor([[-0.8332,  3.9922]])
tensor([[-1.7366,  4.8935]])
tensor([[0.1554, 5.9210]])
tensor([[-0.5020,  7.0563]])
tensor([[0.8593, 8.1128],
        [0.6790, 8.8930]])
tensor([[-1.3425,  9.9942]])
tensor([[ 0.4977, 11.1460]])
tensor([[-0.6046, 12.0957]])
tensor([[ 0.0482, 13.0584]])
tensor([[-3.5339, 14.0837],
        [ 0.3017, 14.8644]])
tensor([[-1.0470, 15.8537]])
tensor([[ 0.5970, 17.0868],
        [ 1.1420, 18.0000]])
tensor([[ 0.2047, 18.9484]])
tensor([[ 0.6992, 19.9839]])
tensor([[-0.7491, 21.0039]])
tensor([[-1.5463, 22.0174]])
tensor([[-1.1534, 23.0604],
        [-1.6386, 23.9969]])
tensor([[-0.6788, 24.9266]])
tensor([[-0.6330, 25.9962]])
tensor([[ 0.8385, 27.0327],
        [-0.5460, 27.9246]])
tensor([[ 1.0593, 29.0393]])
tensor([[ 0.8780, 30.1633]])
tensor([[ 0.5493, 31.0493],
        [ 0.3947, 31.9193]])
tensor([[ 0.1004, 32.8814]])
tensor([[-1.3328, 34.1098]])
tensor([[ 0.5115, 35.0216],
        [-0.1866, 35.9086]])
tensor([[-0.5092, 37.1763],
        [ 0.6438, 37.8571]])
tensor([[ 0.3718, 38.8968]])
tensor([[-1.6322, 40.0297],
        [-0.7793, 40.9798]])
tensor([[ 0.2476, 41.9675]])
tensor([[ 0.4328, 43.0768]])
tensor([[ 0.8585, 44.1284],
        [ 0.1131, 44.8837]])
tensor([[-0.2042, 46.1367]])
tensor([[ 0.7368, 47.0504],
        [-0.1325, 47.9358]])
tensor([[-1.1966, 48.9396]])
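Note that .to(torch.int32) truncates toward zero, so every timestamp in (-1, 1) maps to index 0; with sorted, non-negative timestamps this is the same as taking the floor. If timestamps can be negative, or if a window other than 1 second is wanted, flooring t / window is a safer key. A minimal sketch, where the window parameter and the name get_split_window are hypothetical additions:

```python
import torch

def get_split_window(x, window=1.0):
    # Block index = floor(timestamp / window); unlike .to(torch.int32),
    # flooring keeps negative timestamps out of the [0, window) block
    d = torch.floor(x[:, -1] / window).to(torch.int64)
    # Events per non-empty window (rows are assumed to be sorted by timestamp)
    _, counts = torch.unique(d, return_counts=True)
    return torch.split(x, counts.tolist())
```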

Moreover, assuming that the sequence is sorted, one can replace torch.unique with torch.unique_consecutive. These minor modifications slightly increase the speed compared to your original function:

%timeit get_splits(x)
%timeit get_split_wo_numpy(x)
%timeit get_split_wo_numpy_consecutive(x)
1000 loops, best of 3: 707 µs per loop
1000 loops, best of 3: 565 µs per loop
1000 loops, best of 3: 405 µs per loop
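The body of get_split_wo_numpy_consecutive isn’t shown in the thread; a minimal sketch, assuming it is get_split_wo_numpy with torch.unique swapped for torch.unique_consecutive. unique_consecutive only collapses runs of equal adjacent values instead of computing globally unique values, which is plausibly where the speedup comes from, but it therefore relies on the rows being ordered by timestamp:

```python
import torch

def get_split_wo_numpy_consecutive(x):
    # Integer part of the timestamp = index of its 1-second window
    d = x[:, -1].to(torch.int32)
    # Collapse runs of equal adjacent indices; requires rows sorted by timestamp
    _, counts = torch.unique_consecutive(d, return_counts=True)
    return torch.split(x, counts.tolist())
```

On unsorted input, for example timestamps 0.2, 1.5, 0.7, it returns three blocks instead of two, since the two events from [0, 1) are not adjacent.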

The problem with my for loop is that it assumes that, for every integer i, there is at least one event whose timestamp lies between i and i+1, which is not true in general.
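For comparison, a minimal sketch of the loop with that assumption removed: when an event crosses the current limit, limit jumps straight to floor(t) + 1 instead of being incremented once, so an empty second no longer shifts every later block. It also skips the empty first block that made torch.stack fail when the first timestamp is >= 1. The name reference_fixed is hypothetical, and it still assumes rows sorted by timestamp:

```python
import math
import torch

def reference_fixed(seq):
    blocks = []
    block = []
    limit = 1  # exclusive upper bound of the current 1-second window
    for event in seq:
        t = event[-1].item()
        if t >= limit:
            # Close the current block (empty only if the first event is >= 1)
            if block:
                blocks.append(torch.stack(block, dim=0))
            # Jump to the window containing t, skipping any empty windows
            limit = math.floor(t) + 1
            block = [event]
        else:
            block.append(event)
    if block:
        blocks.append(torch.stack(block, dim=0))
    return blocks
```

It should agree with the torch.unique-based versions, but it is still a Python loop, so the vectorized approach remains the faster choice.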

Thanks again for your solution!

Best,

Alain

Integer rounding was a clever trick! Thanks for sharing the improved approach :slight_smile: