Stacking variable length chunks of a tensor (vectorized solution)

Is there a function that chunks a tensor along a dimension into chunks of given sizes and stacks the chunks into a single tensor? Since the chunks can have different sizes, they would also need to be padded up to the largest chunk.

A = torch.randn(10, 512, 768)
sizes = torch.tensor([1, 3, 6])
assert sizes.sum() == A.size(0)

out = function(A, sizes, dim=0)

assert out.shape == (len(sizes), sizes.max(), 512, 768)

torch.split and torch.nn.utils.rnn.pad_sequence should work (note that torch.split expects a list of Python ints, hence the .tolist()):

A = torch.randn(10, 512, 768)
sizes = torch.tensor([1, 3, 6])
assert sizes.sum() == A.size(0)

out = torch.split(A, sizes.tolist(), 0)
out = torch.nn.utils.rnn.pad_sequence(out, batch_first=True)
assert out.shape == (len(sizes), sizes.max(), 512, 768)
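If you want to avoid materializing the tuple of chunks that torch.split returns, the same padded tensor can be built with a single advanced-indexing assignment into a zero-initialized output. A sketch, assuming the same A and sizes as above (the names row, col, and offsets are just illustrative):

```python
import torch

A = torch.randn(10, 512, 768)
sizes = torch.tensor([1, 3, 6])

# Destination tensor, zero-padded up to the largest chunk.
out = A.new_zeros(len(sizes), int(sizes.max()), *A.shape[1:])

# Which output row each slice of A belongs to: [0, 1, 1, 1, 2, 2, 2, 2, 2, 2]
row = torch.repeat_interleave(torch.arange(len(sizes)), sizes)

# Offset of each chunk's first slice within A: [0, 1, 4]
offsets = torch.cat([sizes.new_zeros(1), sizes.cumsum(0)[:-1]])

# Position of each slice inside its chunk: [0, 0, 1, 2, 0, 1, 2, 3, 4, 5]
col = torch.arange(A.size(0)) - offsets.repeat_interleave(sizes)

# Scatter all slices into place in one shot.
out[row, col] = A
```

This stays entirely in tensor ops, so it also works on GPU without any Python-level iteration over chunks.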