I have a function that takes a 2D tensor
output_ids and a 1D tensor
ignore_span. Using a sliding window approach it creates a mask equal to
penalty wherever the
ignore_span appears in
The function looks as below (the commented line will become apparent later):
import torch batch_size, sequence_length, vocab_size = (4,8,12) #(32, 1024, 2048) #(4,12,8) logits = torch.randint(0,2, size=(batch_size, sequence_length, vocab_size)) ignore_span = torch.tensor([0, 1, 2]) output_ids = torch.argmax(logits, dim=-1) # (batch_size, sequence_length) def penalty_mask_loop(output_ids, ignore_span, penalty=5): """ Slow, looped sliding window approach. """ mask = torch.zeros_like(output_ids) for i in range(output_ids.size(0)): for j in range(output_ids.size(1) - len(ignore_span) + 1): if torch.all(output_ids[i, j:j+len(ignore_span)] == ignore_span): mask[i, j:j+len(ignore_span)] = penalty #mask[i, j] = penalty return mask
I have developed a version that uses tensor operations to avoid the loops and it is much much faster, however the resulting mask is equivalent to the commented line above i.e. it only masks the first index of the span.
@torch.no_grad() def penalty_mask_tensor_ops(output_ids, ignore_span, penalty = 5): """ Faster tensor-op variant. """ mask = torch.ones_like(output_ids) # make all len(ignore_span) sized views of output_ids to simulate a sliding window slices = output_ids.unfold(1, len(ignore_span), 1) # shape (batch_size, sequence_length-len(ignore_span)+1, len(ignore_span)) matches = torch.all(slices == ignore_span, dim=2) # shape (batch_size, sequence_length-len(ignore_span)+1) i.e. along the dimension of the length of the span mask[:, :matches.size(1)] *= (penalty+1) ** matches return mask - 1 loopmask = penalty_mask_loop(output_ids, ignore_span) tensmask = penalty_mask_tensor_ops(output_ids, ignore_span)
Any ideas how I can get this to efficiently mask each occurrence of the full span rather than just the first index of each discovered span match?
Thank you in advance!