You might be interested in nested tensors, a special subclass of tensor designed for batches of data whose entries are jagged, i.e. have varying lengths.
torch.nested.nested_tensor_from_jagged(values=x, offsets=stride)
# NestedTensor(size=(3, j1), offsets=tensor([0, 1, 3, 6]), contiguous=True)
There’s a tutorial here:
https://pytorch.org/tutorials/prototype/nestedtensor.html
(Note that some parts of it are out of date: any call to torch.nested.nested_tensor should be passed a layout=torch.jagged
argument. The part about implementing MHA has been updated relatively recently, though.)
Unfortunately, NestedTensors don’t work with vmap yet, so you’ll need to write out the batch dim explicitly in your program. (We plan to support this soon, though.)