Extracting differences from Attention Masks and Paddings, in order to get logit differences

For better or worse, I have the following scenario for a custom Transformer based model:

  • logits [batch, sequence, features]
  • ideals [batch, sequence, features]

Then I have the attention masks (which I use to test the predictability of the model)

  • attn_masks [batch, sequence]

And finally I have the padding masks (which are pretty standard, padding what is not part of the sequence)

  • paddings [batch, sequence]

I’ve managed to write a very ugly, slow monstrosity which does two things:

  1. it calculates an R2 score from the values spanning the start of the attention mask all the way to the end of the sequence
  2. it calculates an F1 score from the values spanning the start of the attention mask up to the start of the padding (after binarising them)

I accept full responsibility that the code may be buggy.
The main logic is the following (I use PyTorch Lightning, so the data arrive in outputs; I gather the data and concatenate it into single tensors):
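For context, the gathering step at the end of the test epoch looks roughly like this (the dict keys here are illustrative, not the actual names in my code):

```python
import torch

# Hypothetical sketch of the gathering step: assumes each element of
# `outputs` is a dict holding the per-step tensors (key names illustrative).
def gather_outputs(outputs):
    logits = torch.cat([o["logits"] for o in outputs], dim=0)      # [batch, seq, features]
    ideals = torch.cat([o["ideals"] for o in outputs], dim=0)      # [batch, seq, features]
    attns  = torch.cat([o["attn_masks"] for o in outputs], dim=0)  # [batch, seq]
    pads   = torch.cat([o["pads"] for o in outputs], dim=0)        # [batch, seq]
    return logits, ideals, attns, pads
```

The loop below then runs over these concatenated tensors.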

i = 0
for attn_mask, pad_mask in zip(torch.unbind(attns), torch.unbind(pads)):
    idx = attn_mask.nonzero(as_tuple=False)   # positions where attention is True
    end = pad_mask.nonzero(as_tuple=False)    # positions where padding is True

    if len(end) > 0:
        z = torch.min(end).item()             # first padded position
    else:
        z = self.max_seq                      # no padding: use the full sequence

    if len(idx) > 0:
        k = torch.min(idx).item()             # first attended position
        t_x = ideals[i, k:]                   # R2: start of attention to end of sequence
        t_y = logits[i, k:]

        masked_x = ideals[i, k:z]             # F1: start of attention to start of padding
        masked_x_hat = logits[i, k:z]

    i += 1

I then use t_y and t_x to calculate an R2 score, and masked_x and masked_x_hat to binarise and calculate F1 score on a specific feature.

The problem is:

  1. i indexes the sample in the batch (I know it’s horrible), and the loop above unrolls for every single entry in the test batch.
  2. the way idx and end are computed is per-sample in the batch, and the same goes for z.

I suspect there is a way to vectorize this and speed it up considerably.
Playing around with dummy data, I managed to get as far as below:

Assume that it’s a 2-item batch with a 5-length sequence.

pads = torch.tensor([[False, False, False, False, True], [False, False, False, False, False]])
attns = torch.tensor([[False, False, True, True, True], [False, False, False, True, False]])
vals = torch.rand(2, 5)
diffs = torch.eq(attns, pads)
(diffs == False).nonzero(as_tuple=False)

Basically, in this example, for the R2 score I want the values from the first True in the attention mask to the end of the sequence, and for the F1 score I want the values from the first True in the attention mask up to the first True in the padding (or the end of the sequence).
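The closest vectorized sketch I can think of relies on argmax returning the index of the first maximum, so on an integer mask it gives the first True per row (this is only a sketch on the dummy data above, not tested against the real model):

```python
import torch

pads = torch.tensor([[False, False, False, False, True],
                     [False, False, False, False, False]])
attns = torch.tensor([[False, False, True, True, True],
                      [False, False, False, True, False]])
vals = torch.rand(2, 5)

B, S = attns.shape
pos = torch.arange(S, device=attns.device)          # [S]

# argmax on an int tensor returns the index of the FIRST maximum,
# i.e. the first True in each row
k = attns.int().argmax(dim=1)                       # first attended position, [B]
z = torch.where(pads.any(dim=1),
                pads.int().argmax(dim=1),           # first padded position
                torch.full((B,), S))                # no padding: end of sequence

valid = attns.any(dim=1)                            # rows with no True at all are
                                                    # dropped, like the len(idx) > 0 check
r2_mask = (pos.unsqueeze(0) >= k.unsqueeze(1)) & valid.unsqueeze(1)      # [B, S]
f1_mask = r2_mask & (pos.unsqueeze(0) < z.unsqueeze(1))                  # [B, S]

r2_vals = vals[r2_mask]   # flat 1-D tensor of the selected values
f1_vals = vals[f1_mask]
```

With the real [batch, sequence, features] tensors the same boolean masks should select whole feature rows, e.g. logits[r2_mask] would come out as [N, features].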

I have no clue what built-in torch methods exist to achieve this without the loop. The data is on the GPU, if that matters. Any help is much appreciated!