I have a batched tensor of size [k, max_seq_len, embedding_dim] (batch size k). Each row i is the concatenation of two sequences: the first of length n_i and the second of length m_i. I want to split it into two tensors, one containing only the first sequences and the other containing only the second sequences, where each output is padded to max n_i and max m_i respectively.

For example:
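A small hypothetical batch to illustrate the layout. The values, k = 2, embedding_dim = 1, and the assumption that zero padding comes after both sequences are illustrative choices, not taken from real data:

```python
import torch

# Hypothetical sizes: k = 2 rows, embedding_dim = 1 for readability.
n = torch.tensor([2, 1])  # lengths of the first sequences (n_i)
m = torch.tensor([1, 3])  # lengths of the second sequences (m_i)

# Input of shape [k, max_seq_len, embedding_dim] = [2, 4, 1].
# Assumed layout per row: first sequence, then second sequence, then zero padding.
x = torch.tensor([
    [[1.], [2.], [10.], [0.]],    # n_0 = 2, m_0 = 1, one padding slot
    [[3.], [20.], [30.], [40.]],  # n_1 = 1, m_1 = 3, no padding
])

# Desired outputs, padded to max(n_i) = 2 and max(m_i) = 3 respectively.
first = torch.tensor([
    [[1.], [2.]],
    [[3.], [0.]],
])
second = torch.tensor([
    [[10.], [0.], [0.]],
    [[20.], [30.], [40.]],
])
```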

I can do this with a loop over the batch using torch.split, but I am looking for a solution that avoids looping over the batch dimension.
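For reference, a minimal sketch of the loop-based version, assuming the padding sits after both sequences in each row and that n and m are 1-D integer tensors of length k. The function name split_with_loop is hypothetical; re-padding is done with torch.nn.utils.rnn.pad_sequence, which pads with zeros by default:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def split_with_loop(x, n, m):
    """Split each row of x ([k, max_seq_len, D]) into its two sequences.

    Returns tensors of shape [k, max(n_i), D] and [k, max(m_i), D].
    """
    firsts, seconds = [], []
    # This Python loop over the batch is exactly what I would like to avoid.
    for row, n_i, m_i in zip(x, n.tolist(), m.tolist()):
        pad_i = row.shape[0] - n_i - m_i  # trailing padding in this row
        a, b, _ = torch.split(row, [n_i, m_i, pad_i], dim=0)
        firsts.append(a)
        seconds.append(b)
    # pad_sequence pads each list to its longest element (zeros by default).
    return (pad_sequence(firsts, batch_first=True),
            pad_sequence(seconds, batch_first=True))
```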

Thanks!