Hi everyone,
I want to limit how many past tokens the attention mechanism can attend to. I am using the PyTorch Transformer implementation, and I was thinking that one way to achieve this is by modifying the tgt_mask to take the window into account. However, I have some doubts:
- Should I also modify the tgt_key_padding_mask, marking the positions I don't want attended to as padding?
- Or could I achieve the same effect by modifying only the tgt_key_padding_mask?
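To make the question concrete, this is my understanding of the two masks' shapes and conventions (the sizes below are placeholders for illustration):

```python
import torch

seq_len, batch = 5, 2

# tgt_mask is (seq_len, seq_len) and is query-dependent: entry (i, j) is 0
# where query position i may attend to key position j, and -inf where it
# may not. A standard causal mask looks like this:
tgt_mask = torch.triu(
    torch.full((seq_len, seq_len), float("-inf")), diagonal=1
)
print(tgt_mask.shape)  # torch.Size([5, 5])

# tgt_key_padding_mask is (batch, seq_len) and is query-independent: True
# marks padded key positions that *no* query may attend to. It masks the
# same columns for every row, so it cannot express a per-position window.
tgt_key_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
tgt_key_padding_mask[:, -1] = True  # e.g. the last token is padding
print(tgt_key_padding_mask.shape)  # torch.Size([2, 5])
```

If that understanding is right, it seems the padding mask alone can't encode a sliding window, since the window start differs per query position.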
This is the code I have for creating the tgt_mask:
```python
import torch


def create_variable_window_mask(
    size, window_size, dtype=torch.float32, device=torch.device("cpu")
):
    """
    Creates an additive attention mask that restricts each position to a
    window over the most recent tokens.

    Args:
        size (int): The length of the target sequence.
        window_size (int): How many past tokens each position may attend to
            (in addition to itself).

    Returns:
        torch.Tensor: A (size, size) mask with 0 in allowed positions and
        -inf everywhere else.
    """
    mask = torch.full((size, size), float("-inf"), dtype=dtype, device=device)
    for i in range(size):
        # Allow attention to the current token and up to `window_size`
        # previous tokens, clamping at the start of the sequence. (When
        # window_size >= size this reduces to an ordinary causal mask.)
        start = max(0, i - window_size)
        mask[i, start : i + 1] = 0
    return mask
```
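And this is roughly how I intend to pass the mask to the model (the dimensions, the window size, and the random inputs below are just placeholders; the mask construction inlines the same logic as the function above):

```python
import torch
import torch.nn as nn

d_model, nhead, seq_len, batch = 16, 4, 6, 2  # placeholder sizes
window_size = 2

# Same construction as create_variable_window_mask above, inlined:
# each position attends to itself and up to `window_size` previous tokens.
tgt_mask = torch.full((seq_len, seq_len), float("-inf"))
for i in range(seq_len):
    tgt_mask[i, max(0, i - window_size) : i + 1] = 0

model = nn.Transformer(d_model=d_model, nhead=nhead, batch_first=True)
model.eval()

src = torch.randn(batch, seq_len, d_model)
tgt = torch.randn(batch, seq_len, d_model)

with torch.no_grad():
    out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 6, 16])
```

Note that every row of the mask has at least one 0 entry (position i can always attend to itself), so no query is fully masked out and the softmax stays finite.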
Thanks so much in advance!