If you could use numpy, then this method should work:
def get_splits(x):
    """Split a 2-column tensor into chunks grouped by the integer bin of column 1.

    Assumes the rows of ``x`` are sorted (ascending) by column 1, so rows that
    share an integer bin are contiguous.  Returns a tuple of tensors, one per
    occupied bin (same tuple type that ``torch.split`` returns).
    """
    # Guard: .max() on an empty tensor raises, so return no chunks instead.
    if x.numel() == 0:
        return ()
    # Bin edges 0, 1, ..., ceil(max); digitize maps each value to its bin index.
    bins = np.arange(0, np.ceil(x[:, 1].max()) + 1)
    d = torch.from_numpy(np.digitize(x.numpy()[:, 1], bins))
    # unique() returns sorted values; because the input is sorted, the counts
    # line up with the contiguous groups and can be used directly as split sizes.
    _, counts = torch.unique(d, return_counts=True)
    return torch.split(x, counts.tolist())
# Create tensor
c0 = torch.arange(50).view(50, 1).float()
c1 = torch.arange(50).view(50, 1) + torch.randn(50, 1) * 0.1
x = torch.cat((c0, c1), dim=1)
print(get_splits(x))
> (tensor([[ 0.0000, -0.0317]]),
tensor([[1.0000, 1.0548]]),
tensor([[2.0000, 2.1610]]),
tensor([[3.0000, 3.0097]]),
tensor([[4.0000, 4.0108]]),
tensor([[5.0000, 5.0012]]),
tensor([[6.0000, 6.0031],
[7.0000, 6.9909]]),
tensor([[8.0000, 8.2395],
[9.0000, 8.8865]]),
tensor([[10.0000, 9.9980]]),
tensor([[11.0000, 11.0808],
[12.0000, 11.9696]]),
tensor([[13.0000, 12.9840]]),
tensor([[14.0000, 14.0748]]),
tensor([[15.0000, 15.0449]]),
tensor([[16.0000, 16.0221]]),
tensor([[17.0000, 17.2724],
[18.0000, 17.9412]]),
tensor([[19.0000, 19.0168]]),
tensor([[20.0000, 20.0405]]),
tensor([[21.0000, 21.0120],
[22.0000, 21.8588]]),
tensor([[23.0000, 22.9400]]),
tensor([[24.0000, 24.0344],
[25.0000, 24.8939]]),
tensor([[26.0000, 26.0936],
[27.0000, 26.9038]]),
tensor([[28.0000, 27.8985]]),
tensor([[29.0000, 28.9492]]),
tensor([[30.0000, 29.8341]]),
tensor([[31.0000, 30.9124]]),
tensor([[32.0000, 31.9069]]),
tensor([[33.0000, 33.0332],
[34.0000, 33.8976]]),
tensor([[35.0000, 34.8466]]),
tensor([[36.0000, 35.9875]]),
tensor([[37.0000, 36.9951]]),
tensor([[38.0000, 37.9140]]),
tensor([[39.0000, 39.1045]]),
tensor([[40.0000, 40.0363]]),
tensor([[41.0000, 41.0674],
[42.0000, 41.7366]]),
tensor([[43.0000, 42.9243]]),
tensor([[44.0000, 44.1054],
[45.0000, 44.9093]]),
tensor([[46.0000, 45.9525]]),
tensor([[47.0000, 46.9660]]),
tensor([[48.0000, 47.9752]]),
tensor([[49.0000, 48.9403]]))
%timeit get_splits(x)
returns
211 µs ± 10.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
,
while your loop would take
1.35 ms ± 104 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
.
Note, however, that I don’t get valid results using your loop:
def reference(seq):
    """Group rows of ``seq`` into blocks by the integer bin of the last column.

    A new block starts when a row's last element reaches the current limit.
    The limit then jumps to ``floor(value) + 1`` so that later rows in the
    same integer bin stay in the block.  (The original ``limit += 1`` only
    advanced the threshold by one, so after a jump — e.g. 8.24 arriving while
    the limit was still 7 — the very next same-bin value, 8.89, wrongly
    started a new block.)
    """
    blocks = []
    block = []
    limit = 1
    for event in seq:
        if event[-1] >= limit:
            # Guard: torch.stack([]) raises, so skip when the very first
            # event already triggers a split.
            if block:
                blocks.append(torch.stack(block, dim=0))
            # Jump past the current bin instead of incrementing by one.
            limit = event[-1].floor().item() + 1
            block = [event]
        else:
            block.append(event)
    if block:
        blocks.append(torch.stack(block, dim=0))
    return blocks
print(reference(x))
> [tensor([[ 0.0000, -0.0317]]),
tensor([[1.0000, 1.0548]]),
tensor([[2.0000, 2.1610]]),
tensor([[3.0000, 3.0097]]),
tensor([[4.0000, 4.0108]]),
tensor([[5.0000, 5.0012]]),
tensor([[6.0000, 6.0031],
[7.0000, 6.9909]]),
tensor([[8.0000, 8.2395]]),
tensor([[9.0000, 8.8865]]),
tensor([[10.0000, 9.9980]]),
tensor([[11.0000, 11.0808]]),
tensor([[12.0000, 11.9696]]),
tensor([[13.0000, 12.9840]]),
tensor([[14.0000, 14.0748]]),
tensor([[15.0000, 15.0449]]),
tensor([[16.0000, 16.0221]]),
tensor([[17.0000, 17.2724]]),
tensor([[18.0000, 17.9412]]),
tensor([[19.0000, 19.0168]]),
tensor([[20.0000, 20.0405]]),
tensor([[21.0000, 21.0120]]),
tensor([[22.0000, 21.8588]]),
tensor([[23.0000, 22.9400]]),
tensor([[24.0000, 24.0344]]),
tensor([[25.0000, 24.8939]]),
tensor([[26.0000, 26.0936]]),
tensor([[27.0000, 26.9038]]),
tensor([[28.0000, 27.8985]]),
tensor([[29.0000, 28.9492]]),
tensor([[30.0000, 29.8341]]),
tensor([[31.0000, 30.9124]]),
tensor([[32.0000, 31.9069]]),
tensor([[33.0000, 33.0332]]),
tensor([[34.0000, 33.8976]]),
tensor([[35.0000, 34.8466]]),
tensor([[36.0000, 35.9875]]),
tensor([[37.0000, 36.9951]]),
tensor([[38.0000, 37.9140]]),
tensor([[39.0000, 39.1045]]),
tensor([[40.0000, 40.0363]]),
tensor([[41.0000, 41.0674]]),
tensor([[42.0000, 41.7366]]),
tensor([[43.0000, 42.9243]]),
tensor([[44.0000, 44.1054]]),
tensor([[45.0000, 44.9093]]),
tensor([[46.0000, 45.9525]]),
tensor([[47.0000, 46.9660]]),
tensor([[48.0000, 47.9752]]),
tensor([[49.0000, 48.9403]])]