I have a batch of sequences (tensors). Each sequence is of the form <|sos|> … <|eos|> … <|eos|>. I want to truncate each sequence after the first <|eos|> token.

Thanks in advance.

This is not that easy a problem. One thing you can do is use `arange` to get a tensor of indices and then `where` to replace the index with a large value such as 2000 (anything > seqlen) at every position you want to ignore. Then the minimum of these masked indices will be the position of the first token you did not mask.

Then you can use these to get a slice of the tensor. Note that if you do work on the GPU, you will have a synchronization point because the slicing needs the index on the CPU.
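
Here is a rough sketch of that idea, not a definitive implementation. It assumes the batch is a 2D integer tensor of token ids with shape (batch, seqlen), that "truncate after" means the first <|eos|> itself is kept, and that a row without any <|eos|> is left as is. The function name `truncate_after_first_eos`, the id `EOS_ID = 2` and the sample values are made up for illustration. Instead of a fixed 2000 it uses seqlen as the "large" value.

```python
import torch

def truncate_after_first_eos(batch, eos_id):
    """Return a list of 1D tensors, each cut right after its first eos token."""
    seq_len = batch.size(1)
    # Position index of every token, repeated for each row: shape (batch, seq_len).
    idx = torch.arange(seq_len, device=batch.device).expand_as(batch)
    # Keep the index where the token is eos; everywhere else use seq_len,
    # which is larger than any valid position.
    masked = torch.where(batch == eos_id, idx, torch.full_like(idx, seq_len))
    # The per-row minimum is the position of the first eos
    # (or seq_len if the row has no eos at all).
    first_eos = masked.min(dim=1).values
    # .tolist() copies the indices to the CPU; on the GPU this is the sync point.
    return [row[: end + 1] for row, end in zip(batch, first_eos.tolist())]

EOS_ID = 2  # hypothetical token id for <|eos|>
batch = torch.tensor([
    [1, 5, 2, 7, 2],
    [1, 2, 9, 9, 2],
    [1, 4, 6, 8, 9],  # no eos: the row is returned unchanged
])
truncated = truncate_after_first_eos(batch, EOS_ID)
```

The result is a Python list rather than a tensor because the truncated rows generally have different lengths.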

Best regards

Thomas