For better or worse, I have the following scenario for a custom Transformer based model:
- logits [batch, sequence, features]
- ideals [batch, sequence, features]
Then I have the attention masks (which I use to test predictability of the model)
- attn_masks [batch, sequence]
And finally I have the padding masks (which are pretty standard, padding what is not part of the sequence)
- paddings [batch, sequence]
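For concreteness, a minimal dummy setup with these shapes (sizes and names here are mine, values random):

```python
import torch

batch, seq, feats = 2, 5, 3

# Model outputs and targets: [batch, sequence, features]
logits = torch.rand(batch, seq, feats)
ideals = torch.rand(batch, seq, feats)

# Per-timestep boolean masks: [batch, sequence]
attn_masks = torch.zeros(batch, seq, dtype=torch.bool)
paddings = torch.zeros(batch, seq, dtype=torch.bool)
```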
I’ve managed to write a very ugly, slow, monstrosity which does two things:
- it calculates the R2 score by looking at the values from the start of the attention mask, all the way to the end of the sequence
- it calculates an F1 score, by extracting the values from the start of the attention mask, up to the start of the padding (and binarizes it)
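For reference, what I then do with the extracted slices is roughly the following (a torch-only sketch of the two metrics; the 0.5 threshold, the feature index, and the variable names are mine):

```python
import torch

torch.manual_seed(0)
# Stand-ins for the extracted slices:
t_x = torch.rand(4, 3)   # ideals[i, k:]  -> ground truth
t_y = torch.rand(4, 3)   # logits[i, k:]  -> predictions

# R2 = 1 - SS_res / SS_tot over the whole slice
ss_res = ((t_x - t_y) ** 2).sum()
ss_tot = ((t_x - t_x.mean()) ** 2).sum()
r2 = 1 - ss_res / ss_tot

# F1 on one feature, binarized at an assumed 0.5 threshold
feat = 0
y_true = t_x[:, feat] > 0.5
y_pred = t_y[:, feat] > 0.5
tp = (y_true & y_pred).sum().float()
fp = (~y_true & y_pred).sum().float()
fn = (y_true & ~y_pred).sum().float()
f1 = 2 * tp / (2 * tp + fp + fn + 1e-8)
```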
I accept full responsibility that the code may be buggy.
The main logic is the following (I use PyTorch Lightning, so the data arrive in `outputs`; I then gather the data and concatenate it into single tensors):
```python
i = 0
for attn_mask, pad_mask in zip(torch.unbind(attns), torch.unbind(pads)):
    idx = (attn_mask == True).nonzero(as_tuple=False)
    end = (pad_mask == True).nonzero(as_tuple=False)
    if len(end) > 0:
        z = torch.min(end).item()   # first padded position
    else:
        z = self.max_seq            # no padding: use the full sequence
    if len(idx) > 0:
        k = torch.min(idx).item()   # first attended position
        t_x = ideals[i, k:]         # targets, from k to end of sequence (R2)
        t_y = logits[i, k:]         # predictions, from k to end of sequence (R2)
        masked_x = ideals[i, k:z]       # targets, from k up to padding (F1)
        masked_x_hat = logits[i, k:z]   # predictions, from k up to padding (F1)
    i += 1
```
I then use `t_y` and `t_x` to calculate an R2 score, and `masked_x` and `masked_x_hat` to binarise and calculate an F1 score on a specific feature.
The problems are:
- `i` is the batch index (I know it's horrible), and the loop above unrolls for each and every entry in the test batch.
- `idx` and `end` are computed per-sample in the batch. Same for `z`.
I suspect there is a way to vectorize this and speed it up considerably.
Playing around with dummy data, I managed to get as far as below:
Assume a 2-item batch and a 5-length sequence.

```python
pads = torch.tensor([[False, False, False, False, True],
                     [False, False, False, False, False]])
attns = torch.tensor([[False, False, True, True, True],
                      [False, False, False, True, False]])
vals = torch.rand(2, 5)

diffs = torch.eq(attns, pads)
(diffs == False).nonzero(as_tuple=False)
```
Basically, in this example, for the R2 score I want the values starting from the first `True` in the attention mask to the end of the sequence, and for the F1 score I want the values starting from the first `True` in attention, up to the first `True` in the padding (or the end of the sequence).
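To pin down what I mean by those ranges, here is what the per-sample loop should pick out on that dummy batch (a sketch; R2 would use `[k:seq_len]` and F1 would use `[k:z]`):

```python
import torch

pads = torch.tensor([[False, False, False, False, True],
                     [False, False, False, False, False]])
attns = torch.tensor([[False, False, True, True, True],
                      [False, False, False, True, False]])
seq_len = attns.shape[1]

ranges = []  # (k, z) per sample
for b in range(attns.shape[0]):
    k = int(attns[b].nonzero()[0])                     # first True in attention
    pad_idx = pads[b].nonzero()
    z = int(pad_idx[0]) if len(pad_idx) else seq_len   # first True in padding, else end
    ranges.append((k, z))
print(ranges)  # [(2, 4), (3, 5)]
```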
I have no clue what built-in `torch` methods exist to achieve this without the loop. The data is on the GPU, if that matters. Any help is much appreciated!