Is there a PyTorch function that splits a tensor along a dimension into chunks of given sizes, pads the chunks to a common length (since they can have different sizes), and stacks them into a single tensor? For example:
import torch
A = torch.randn(10, 512, 768)
sizes = torch.tensor([1, 3, 6])  # chunk lengths along dim 0
assert sizes.sum() == A.size(0)
out = function(A, sizes, dim=0)  # `function` is the op I'm looking for
assert out.shape == (len(sizes), sizes.max().item(), 512, 768)
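One way to get this behaviour is to combine two existing PyTorch calls: `torch.split`, which accepts a list of section sizes, and `torch.nn.utils.rnn.pad_sequence`, which right-pads a list of tensors along their first dimension and stacks them. The sketch below assumes zero padding and that the output should put the chunk index first and the padded dimension second, as in the `assert` above; the name `split_pad_stack` is hypothetical, not a built-in.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def split_pad_stack(x: torch.Tensor, sizes: torch.Tensor, dim: int = 0,
                    padding_value: float = 0.0) -> torch.Tensor:
    """Hypothetical helper: split `x` along `dim` into chunks of lengths
    `sizes`, right-pad each chunk with `padding_value` to the longest
    chunk length, and stack the results.

    Output shape: (len(sizes), sizes.max(), *other dims of x in original order).
    """
    # Move the split dimension to the front, since pad_sequence pads dim 0.
    x = x.movedim(dim, 0)
    # torch.split takes a Python list of ints for variable-sized sections.
    chunks = torch.split(x, sizes.tolist(), dim=0)
    # batch_first=True yields (num_chunks, max_len, *rest).
    return pad_sequence(chunks, batch_first=True, padding_value=padding_value)
```

With the tensors from the question, `split_pad_stack(A, sizes, dim=0)` should have shape `(3, 6, 512, 768)`. If you later need to know which positions are padding, a boolean mask can be built from `sizes` alone, e.g. `torch.arange(int(sizes.max()))[None, :] < sizes[:, None]`.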