Cross-posted from here, but this version narrows down the question.
Suppose I have a matrix of rows of integers containing BIO outputs (assume O = 0, B = 1, I = 2).
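import torch

bio_outputs = torch.tensor([[0, 0, 1, 2, 2, 1, 2, 0, 1, 1],
#                                  1  1  1  2  2     3  4
                            [1, 2, 2, 0, 0, 0, 1, 1, 0, 0]])
#                            5  5  5           6  7
# (the numbers under each row mark which mention a token belongs to)
begin_positions = torch.nonzero(bio_outputs == 1)  # shape (M, 2): [row, column]
end_positions = ...  # this is what I need to compute
assert torch.equal(end_positions, torch.tensor([4, 6, 8, 9, 2, 6, 7]))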
The challenge is to compute end_positions. The end position of a mention is the index of the last 2 (I) that follows its 1 (B), i.e. the last I before an O, a new 1 (B), or the end of the row; if no I follows the B, the end position is the start index itself. I would like end_positions to be a tensor of shape (M,), where M is the number of detected 1s (Bs), 7 in the example above. That way I can torch.cat it onto begin_positions and get the start and end index of each mention.
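To make the definition concrete, here is a minimal loop-based reference sketch of the output I am after. It is deliberately not vectorized, and the helper name end_positions_loop is hypothetical, used only for illustration. It assumes the encoding above (O = 0, B = 1, I = 2) and that a mention is a B followed directly by zero or more Is.

```python
import torch

def end_positions_loop(bio_outputs):
    """Reference (non-vectorized) end column per B tag, in row-major order."""
    ends = []
    for row in bio_outputs.tolist():
        for start, tag in enumerate(row):
            if tag != 1:  # only a B (1) opens a mention
                continue
            end = start
            # Extend over the run of I (2) tags that directly follows the B.
            while end + 1 < len(row) and row[end + 1] == 2:
                end += 1
            ends.append(end)
    return torch.tensor(ends, dtype=torch.long)

bio_outputs = torch.tensor([[0, 0, 1, 2, 2, 1, 2, 0, 1, 1],
                            [1, 2, 2, 0, 0, 0, 1, 1, 0, 0]])
end_positions = end_positions_loop(bio_outputs)

# Combine with the begin positions: one [row, start, end] triple per mention.
begin_positions = torch.nonzero(bio_outputs == 1)
spans = torch.cat([begin_positions, end_positions.unsqueeze(1)], dim=1)
```

The ordering matches torch.nonzero, which returns indices in row-major order, so each end lines up with its begin position.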
I am using PyTorch 1.7.0 and cannot use experimental features like vmap.
Any ideas? Is this problem worth vectorizing?