Optimizing sliding window function with Tensor operations

Hi all :slight_smile:

I have a function that takes a 2D tensor output_ids and a 1D tensor ignore_span. Using a sliding-window approach, it builds a mask that equals penalty at every position covered by an occurrence of ignore_span in output_ids, and zero everywhere else.
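
For example, on a small hand-written case (the toy_* names are mine, just to pin down the intended behaviour):

import torch

toy_ids = torch.tensor([[3, 0, 1, 2, 7],
                        [0, 1, 2, 0, 1]])
toy_span = torch.tensor([0, 1, 2])
# every position covered by an occurrence of the span gets penalty (here 5):
toy_expected = torch.tensor([[0, 5, 5, 5, 0],
                             [5, 5, 5, 0, 0]])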

The function looks like this (the relevance of the commented-out line will become apparent later):

import torch

batch_size, sequence_length, vocab_size = (4, 8, 12)  # small test sizes; realistic sizes are more like (32, 1024, 2048)
logits = torch.randint(0, 2, size=(batch_size, sequence_length, vocab_size))
ignore_span = torch.tensor([0, 1, 2])
output_ids = torch.argmax(logits, dim=-1) # (batch_size, sequence_length)


def penalty_mask_loop(output_ids, ignore_span, penalty=5):
    """
    Slow, looped sliding-window approach.
    """
    mask = torch.zeros_like(output_ids)
    for i in range(output_ids.size(0)):  # batch dimension
        for j in range(output_ids.size(1) - len(ignore_span) + 1):  # window starts
            if torch.all(output_ids[i, j:j + len(ignore_span)] == ignore_span):
                mask[i, j:j + len(ignore_span)] = penalty  # mark the whole span
                # mask[i, j] = penalty  # variant: mark only the start index
    return mask
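
As a quick sanity check, the looped version reproduces the hand-written expectation from the toy case above exactly:

assert torch.equal(penalty_mask_loop(toy_ids, toy_span), toy_expected)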

I have developed a version that uses tensor operations to avoid the loops, and it is much faster. However, the resulting mask is equivalent to the one produced by the commented-out line above, i.e. it only marks the first index of each matched span.

See below:

@torch.no_grad()
def penalty_mask_tensor_ops(output_ids, ignore_span, penalty=5):
    """
    Faster tensor-op variant (but only marks the first index of each match).
    """
    mask = torch.ones_like(output_ids)
    # All len(ignore_span)-sized views of output_ids, simulating a sliding window;
    # shape (batch_size, sequence_length - len(ignore_span) + 1, len(ignore_span))
    slices = output_ids.unfold(1, len(ignore_span), 1)
    # Compare each window to the span along its last dimension;
    # shape (batch_size, sequence_length - len(ignore_span) + 1)
    matches = torch.all(slices == ignore_span, dim=2)
    # Start from ones, multiply by (penalty + 1) wherever a window matches,
    # then subtract 1, so match starts end up at penalty and the rest at 0.
    mask[:, :matches.size(1)] *= (penalty + 1) ** matches
    return mask - 1

loopmask = penalty_mask_loop(output_ids, ignore_span)
tensmask = penalty_mask_tensor_ops(output_ids, ignore_span)
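
Running both versions on the toy case from the top of the post makes the discrepancy obvious:

print(penalty_mask_loop(toy_ids, toy_span))
# tensor([[0, 5, 5, 5, 0],
#         [5, 5, 5, 0, 0]])
print(penalty_mask_tensor_ops(toy_ids, toy_span))
# tensor([[0, 5, 0, 0, 0],
#         [5, 0, 0, 0, 0]])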

Any ideas how I can efficiently mark every position of each full span occurrence, rather than just the first index of each match?
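
One rough idea (just a sketch, the _v2 name is made up) is to broadcast each detected start forward over the span with a short loop over the offsets, so the Python loop runs only len(ignore_span) times instead of over batch and sequence. I would still prefer a fully vectorized formulation if one exists:

@torch.no_grad()
def penalty_mask_tensor_ops_v2(output_ids, ignore_span, penalty=5):
    slices = output_ids.unfold(1, len(ignore_span), 1)
    matches = torch.all(slices == ignore_span, dim=2)  # (batch, n_windows)
    hit = torch.zeros_like(output_ids, dtype=torch.bool)
    # each match start switches on the following len(ignore_span) positions
    for offset in range(len(ignore_span)):
        hit[:, offset:offset + matches.size(1)] |= matches
    return hit.long() * penalty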

Thank you in advance!