A vectorized way to find mention spans?

Crossposted from here, but this version narrows down the question.

Suppose I have a matrix whose rows are sequences of integer BIO tags (assume O = 0, B = 1, I = 2). The comment under each row numbers the mention that each token belongs to.

import torch

bio_outputs = torch.tensor([[0, 0, 1, 2, 2, 1, 2, 0, 1, 1],
                            #      1  1  1  2  2     3  4
                            [1, 2, 2, 0, 0, 0, 1, 1, 0, 0]])
                            #5  5  5           6  7

begin_positions = torch.nonzero(bio_outputs == 1)  # shape (M, 2): (row, column) of every B
end_positions = ...  # what I want: shape (M,), the end column of each mention
assert torch.equal(end_positions, torch.tensor([4, 6, 8, 9, 2, 6, 7]))

The challenge is to compute end_positions. The end position of a mention is the index of its last I (2), that is, the last I before the next O, the next B, or the end of the row; if the B is not followed by any I, the end position is the start index itself. I would like end_positions to be a 1-D tensor of shape (M,), where M is the number of detected Bs (7 in the example above), so that I can torch.cat it (after end_positions.unsqueeze(1)) with begin_positions and get the row, start, and end index of each mention.
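The rule above can be written directly as a per-mention loop. This is a minimal sketch, not a vectorized answer: `find_end_positions_loop` is a hypothetical helper name, and it assumes the tag encoding from the question (I = 2) and begin positions in the `(row, column)` form returned by `torch.nonzero`. It may be useful as a correctness reference and as a timing baseline for any vectorized version.

```python
import torch


def find_end_positions_loop(bio, begin_positions, i_tag=2):
    """Hypothetical reference helper: walk right from each B while the tag is I."""
    n_cols = bio.shape[1]
    ends = []
    for row, start in begin_positions.tolist():
        end = start
        # Extend over consecutive I tokens; stop at an O, a B, or the end of the row.
        while end + 1 < n_cols and bio[row, end + 1].item() == i_tag:
            end += 1
        ends.append(end)
    return torch.tensor(ends, dtype=torch.long)


bio_outputs = torch.tensor([[0, 0, 1, 2, 2, 1, 2, 0, 1, 1],
                            [1, 2, 2, 0, 0, 0, 1, 1, 0, 0]])
begin_positions = torch.nonzero(bio_outputs == 1)
end_positions_loop = find_end_positions_loop(bio_outputs, begin_positions)
```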

I am using PyTorch 1.7.0 and cannot use experimental features like vmap.

Any ideas? Is this problem worth vectorizing?
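One possible vectorized approach, sketched under the assumptions of the example (O = 0, B = 1, I = 2, begin positions from `torch.nonzero`): treat every O or B as a boundary, compute for each column the index of the next boundary with a reverse cumulative minimum, and take the first boundary strictly after each B, minus one. It only uses `torch.flip`, `torch.cummin`, `torch.cat`, and advanced indexing (no vmap), which should all be available in PyTorch 1.7. `find_end_positions` is a hypothetical name, and this is a sketch rather than a tested solution.

```python
import torch


def find_end_positions(bio, begin_positions, i_tag=2):
    """Hypothetical helper: inclusive end column of each mention.

    bio: (R, C) integer tensor of BIO tags (assumed O = 0, B = 1, I = 2).
    begin_positions: (M, 2) tensor of (row, column) pairs, as from torch.nonzero(bio == 1).
    Returns an (M,) tensor, one end column per B.
    """
    n_rows, n_cols = bio.shape
    # Column index of every token that can stop a run of I's (an O or a B);
    # I tokens get the sentinel n_cols so they never win the minimum below.
    boundary_idx = torch.arange(n_cols, device=bio.device).repeat(n_rows, 1)
    boundary_idx[bio == i_tag] = n_cols
    # Extra sentinel column: a mention reaching the end of its row stops at n_cols.
    sentinel = torch.full((n_rows, 1), n_cols, dtype=torch.long, device=bio.device)
    boundary_idx = torch.cat([boundary_idx, sentinel], dim=1)
    # next_boundary[r, j] = smallest boundary column >= j (reverse cumulative min).
    next_boundary = torch.flip(
        torch.cummin(torch.flip(boundary_idx, dims=[1]), dim=1).values, dims=[1]
    )
    rows, starts = begin_positions[:, 0], begin_positions[:, 1]
    # First boundary strictly after the B, minus one = last token of the mention.
    return next_boundary[rows, starts + 1] - 1


bio_outputs = torch.tensor([[0, 0, 1, 2, 2, 1, 2, 0, 1, 1],
                            [1, 2, 2, 0, 0, 0, 1, 1, 0, 0]])
begin_positions = torch.nonzero(bio_outputs == 1)
end_positions = find_end_positions(bio_outputs, begin_positions)
# (M, 3) tensor of (row, start, end) per mention
spans = torch.cat([begin_positions, end_positions.unsqueeze(1)], dim=1)
print(end_positions)  # tensor([4, 6, 8, 9, 2, 6, 7])
```

The sentinel column means a mention that runs to the end of its row needs no special case, and because a B is itself a boundary, consecutive Bs (columns 8 and 9 in the first row) each get a single-token span.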