Windowed attention

Hi everyone,

I want to limit the number of past tokens the attention mechanism can attend to. I want to use the PyTorch Transformer implementation, and I was thinking that one way to achieve this is by modifying the tgt_mask to take the window into account. However, I’m having some doubts:

  • Should I also modify the tgt_key_padding_mask, marking the positions I don’t want it to attend to as padding?
  • Or could I get the same effect by modifying only the tgt_key_padding_mask? (My intended usage is sketched below.)
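
For context, this is roughly how I’m planning to wire the two masks into nn.Transformer, using the function below (a minimal sketch; the model dimensions, batch shapes, and the PAD_IDX value are placeholders I made up for illustration):

import torch
import torch.nn as nn

PAD_IDX = 0  # hypothetical padding id, just for this example
model = nn.Transformer(d_model=64, nhead=4, batch_first=True)

src = torch.randn(2, 10, 64)  # (batch, src_len, d_model)
tgt = torch.randn(2, 7, 64)   # (batch, tgt_len, d_model)

# Pretend the second target sequence has two padded positions at the end.
tgt_tokens = torch.randint(1, 100, (2, 7))
tgt_tokens[1, 5:] = PAD_IDX

# Windowed causal mask: (tgt_len, tgt_len), shared across the batch.
tgt_mask = create_variable_window_mask(7, window_size=3)

# Padding mask: (batch, tgt_len), True where a position is padding.
tgt_key_padding_mask = tgt_tokens == PAD_IDX

out = model(
    src,
    tgt,
    tgt_mask=tgt_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
)

One thing I noticed: tgt_mask here is a float (-inf/0) mask while tgt_key_padding_mask is boolean, and some PyTorch versions warn about mixing the two dtypes, so it may be cleaner to convert one to match the other.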

This is the code I have for creating the tgt_mask:

import torch

def create_variable_window_mask(
    size, window_size, dtype=torch.float32, device=torch.device("cpu")
):
    """
    Creates an additive attention mask with a sliding window.

    Args:
        size (int): The length of the target sequence.
        window_size (int): How many previous tokens each position may attend
            to, in addition to itself.

    Returns:
        torch.Tensor: A (size, size) float mask with 0 at allowed positions
            and -inf at blocked ones.
    """
    # Start fully blocked, then open up the allowed window row by row.
    mask = torch.full((size, size), float("-inf"), dtype=dtype, device=device)
    for i in range(size):
        # Position i may attend to itself and up to window_size tokens before it.
        # When window_size >= size this degenerates to a plain causal mask.
        start = max(0, i - window_size)
        mask[i, start : i + 1] = 0
    return mask
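
For reference, this is the kind of output it produces. Note that each position can see itself plus window_size previous tokens, so a window of 2 means 3 visible tokens per row:

>>> create_variable_window_mask(5, 2)
tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [-inf, 0., 0., 0., -inf],
        [-inf, -inf, 0., 0., 0.]])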

Thanks so much in advance!