Masking in PyTorch Transformer

Hello everyone,

I’ve been looking for a guide on how to correctly use the PyTorch transformer modules, in particular their masking. I have to admit I am still a bit lost and would love some guidance.

I am trying to write a GPT-like model that will be trained in an unsupervised manner on variable-length sequences to predict the next token in the sequence. To that end, I have prepared my data so that I can generate minibatches consisting of:

  • input: a batch of sequences of feature vectors, zero-padded so that all sequences in the batch have the same length
  • target: a batch of sequences as above, but shifted by one time step so that each position holds the next token (if the input sequence is [1,2,3,4], the target is [2,3,4,5])
  • padding mask: a batch of padding masks with 0 where there is a value and 1 where there is padding
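A minimal sketch of how such a batch could be built from a list of variable-length sequences, assuming each sequence is a `[length, num_features]` tensor of length at least 2. `make_batch` is a hypothetical helper, not part of PyTorch; `torch.nn.utils.rnn.pad_sequence` does the zero-padding, and the padding mask is returned as a bool tensor (True = padding), which is the form PyTorch's `*_key_padding_mask` arguments interpret as "ignore this key".

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def make_batch(sequences):
    """Hypothetical helper: build (input, target, padding_mask) from [T_i, F] tensors.

    Input drops the last step and target drops the first step,
    so target[t] == full_sequence[t + 1].
    """
    inputs = [s[:-1] for s in sequences]
    targets = [s[1:] for s in sequences]
    lengths = torch.tensor([x.shape[0] for x in inputs])

    # zero-pad to the longest sequence in the batch -> [B, T_max, F]
    inp = pad_sequence(inputs, batch_first=True, padding_value=0.0)
    tgt = pad_sequence(targets, batch_first=True, padding_value=0.0)

    # True (1) where the position is padding, False (0) where there is a value
    t_max = inp.shape[1]
    padding_mask = torch.arange(t_max).unsqueeze(0) >= lengths.unsqueeze(1)
    return inp, tgt, padding_mask


seqs = [torch.arange(1.0, 6.0).unsqueeze(-1),   # [1, 2, 3, 4, 5]
        torch.arange(1.0, 4.0).unsqueeze(-1)]   # [1, 2, 3]
inp, tgt, padding_mask = make_batch(seqs)
```

For the first sequence this gives input `[1, 2, 3, 4]` and target `[2, 3, 4, 5]`; the second is padded to `[1, 2, 0, 0]` with mask `[False, False, True, True]`.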

My model is a bit more complicated because my data consist of different types of inputs, but it is a decoder-only architecture, defined as follows:

```python
import torch
import torch.nn as nn

# PositionalEncoding is my own module, defined elsewhere (not shown here)


class TransformerModel(nn.Module):
    def __init__(self, dataset, num_features, d_model, nhead, num_layers, dim_feedforward, dropout):
        super(TransformerModel, self).__init__()

        self.embedding = nn.Linear(num_features, d_model)
        self.positional_encoding = PositionalEncoding(d_model)

        # norm_first is a constructor argument of the layer, not of TransformerDecoder.forward()
        # (nowadays, layer norm is usually applied before the attention block)
        self.transformer_layers = nn.TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerDecoder(self.transformer_layers, num_layers)

        self.mean = nn.Linear(d_model, dataset.continuous_length)
        self.var = nn.Sequential(
            nn.Linear(d_model, dataset.continuous_length),
            nn.Softplus(),
        )
        self.binary_model = nn.Linear(d_model, dataset.binary_length)
        self.onehot_model = nn.Linear(d_model, dataset.onehot_length)

    def generate_square_subsequent_mask(self, size, device=None):
        # -inf strictly above the diagonal, 0 elsewhere: position i may only attend to j <= i
        return torch.triu(torch.full((size, size), float('-inf'), device=device), diagonal=1)

    def forward(self, src, padding_mask=None, causality_mask=None):
        # src: [batch, seq_len, num_features]; padding_mask: [batch, seq_len], 1/True = padding
        src = self.embedding(src)
        src = self.positional_encoding(src)

        # causal mask sized to the actual sequence length, on the same device as the input
        m = causality_mask
        if m is None:
            m = self.generate_square_subsequent_mask(src.shape[1], device=src.device)

        if padding_mask is not None:
            # a float key padding mask is *added* to the attention scores, so a 0/1 mask
            # would not block anything; convert it to 0 / -inf, matching the causal mask
            is_pad = padding_mask.to(src.device).bool()
            padding_mask = torch.zeros(is_pad.shape, dtype=src.dtype, device=src.device).masked_fill(is_pad, float('-inf'))

        x = self.transformer(
            src, src,                               # target and memory are the same
            tgt_mask=m, memory_mask=m,              # triangular masks so that we do not attend to future tokens
            tgt_key_padding_mask=padding_mask,      # padded positions are never attended to as keys
            memory_key_padding_mask=padding_mask,
        )

        # process the outputs
        c_mean = self.mean(x)
        c_var = self.var(x)
        b = torch.sigmoid(self.binary_model(x))
        oh = self.onehot_model(x)  # should be raw logits

        return c_mean, c_var, b, oh
```
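As a sanity check on the hand-written `generate_square_subsequent_mask`, the sketch below compares it with the built-in helper `nn.Transformer.generate_square_subsequent_mask` (a static method in recent PyTorch versions) and with the equivalent boolean form, where True means "not allowed to attend". The function names `causal_float_mask` and `causal_bool_mask` are illustrative only. It also shows the padding mask being converted to bool, since a float 0/1 mask is added to the attention scores rather than blocking them.

```python
import torch
import torch.nn as nn


def causal_float_mask(size, device=None):
    # additive mask: 0 where attention is allowed, -inf strictly above the diagonal
    return torch.triu(torch.full((size, size), float("-inf"), device=device), diagonal=1)


def causal_bool_mask(size, device=None):
    # boolean mask: True means "not allowed to attend" (same pattern as above)
    return torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)


size = 4
float_mask = causal_float_mask(size)
bool_mask = causal_bool_mask(size)

# the hand-written mask matches the built-in helper
assert torch.equal(float_mask, nn.Transformer.generate_square_subsequent_mask(size))
# both encode the same pattern: -inf exactly where the boolean mask is True
assert torch.equal(torch.isinf(float_mask), bool_mask)

# a 0/1 padding mask should become bool (True = padding) before it is passed
# as a *_key_padding_mask; a float 0/1 mask would only add 1 to some scores
padding_mask = torch.tensor([[0, 0, 1, 1]])
key_padding_mask = padding_mask.bool()
```

If the causal mask is kept as float (-inf), it is cleaner to convert the padding mask to a float 0 / -inf mask too, since recent PyTorch versions warn when the two mask types differ.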

The model has 4 outputs, but that is not the important part. The important part is the masks, since I am not sure I am using them right.
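One point about the padding masks: `tgt_key_padding_mask` and `memory_key_padding_mask` only stop other positions from attending to padded keys. The model still produces outputs at padded positions, so those have to be excluded from the loss as well. Below is a sketch for the continuous head, assuming a Gaussian negative log-likelihood on `(c_mean, c_var)`; that loss is an assumption, since the post does not say which loss is used. `masked_gaussian_nll` is a hypothetical helper, and the same masking applies to the binary and one-hot heads.

```python
import torch


def masked_gaussian_nll(c_mean, c_var, target, padding_mask, eps=1e-6):
    """Hypothetical helper: Gaussian NLL averaged over non-padded positions only.

    c_mean, c_var, target: [batch, seq_len, n]
    padding_mask: [batch, seq_len], 1/True = padding
    The constant 0.5 * log(2 * pi) term is dropped.
    """
    var = c_var.clamp_min(eps)
    nll = 0.5 * (torch.log(var) + (target - c_mean) ** 2 / var)  # [batch, seq_len, n]
    valid = (~padding_mask.bool()).unsqueeze(-1).to(nll.dtype)   # 1 where there is a value
    # zero out padded positions and divide by the number of valid elements
    return (nll * valid).sum() / (valid.sum() * nll.shape[-1]).clamp_min(1.0)


torch.manual_seed(0)
c_mean = torch.randn(2, 4, 3)
c_var = torch.rand(2, 4, 3) + 0.5
target = torch.randn(2, 4, 3)
padding_mask = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1]])

loss = masked_gaussian_nll(c_mean, c_var, target, padding_mask)
```

Changing the targets at padded positions leaves `loss` unchanged, which is the point of the mask.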

In a GPT-like architecture a token must not attend to future tokens; for this purpose I am using the triangular matrix with -inf above the diagonal and zeros elsewhere. My question is: do I use the same mask for both tgt_mask and memory_mask? Since there is no encoder, target and memory are the same input, so I would suppose both need the same mask as well. Is that correct? The same goes for the padding masks: they are identical, since target and memory are identical.
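One way to check the masking empirically is to change only the last token and confirm that the outputs at earlier positions do not move. The sketch below uses small, made-up sizes, with `dropout=0.0` and `eval()` so that the runs are deterministic. With memory equal to the input, cross-attention is a second path to future tokens, so it compares passing the causal mask as both `tgt_mask` and `memory_mask` against passing it as `tgt_mask` only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, seq_len = 16, 6  # small illustrative sizes

layer = nn.TransformerDecoderLayer(d_model, nhead=2, dim_feedforward=32, dropout=0.0,
                                   batch_first=True, norm_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()
mask = nn.Transformer.generate_square_subsequent_mask(seq_len)


def run(x, memory_mask):
    # target and memory are the same tensor, as in the model above
    with torch.no_grad():
        return decoder(x, x, tgt_mask=mask, memory_mask=memory_mask)


x = torch.randn(1, seq_len, d_model)
x_changed = x.clone()
x_changed[:, -1] += 1.0  # change only the last ("future") token

# causal mask on both attention blocks: earlier outputs are unaffected
both = (run(x, mask)[:, :-1] - run(x_changed, mask)[:, :-1]).abs().max().item()
# causal mask on self-attention only: cross-attention over memory=x still sees the last token
tgt_only = (run(x, None)[:, :-1] - run(x_changed, None)[:, :-1]).abs().max().item()

print(f"max change with both masks:    {both:.2e}")
print(f"max change with tgt_mask only: {tgt_only:.2e}")
```

The first number should be (numerically) zero and the second clearly nonzero. The same reasoning applies to the padding masks: with memory equal to the input, `memory_key_padding_mask` needs the same mask as `tgt_key_padding_mask`.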

Thanks for any suggestions. I just want to make sure I am using the framework the right way (I do not want to implement it manually, as most tutorials do…).
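For comparison, a commonly used alternative (swapped in here, not what the post does) builds a decoder-only, GPT-style stack from `nn.TransformerEncoderLayer`. That layer has self-attention and a feed-forward block but no cross-attention, so there is no memory and only one causal mask and one padding mask are needed. The class name `CausalStack` and the sizes in the usage lines are hypothetical. Both masks are boolean here (True = blocked), so their types match.

```python
import torch
import torch.nn as nn


class CausalStack(nn.Module):
    """Hypothetical GPT-style stack: self-attention + feed-forward only, no cross-attention."""

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
                                           batch_first=True, norm_first=True)
        # nested tensors are not used with norm_first=True anyway; disabling avoids a warning
        self.blocks = nn.TransformerEncoder(layer, num_layers, enable_nested_tensor=False)

    def forward(self, x, padding_mask=None):
        # x: [batch, seq_len, d_model]; padding_mask: [batch, seq_len], 1/True = padding
        size = x.shape[1]
        # True strictly above the diagonal: position i may only attend to j <= i
        causal = torch.triu(torch.ones(size, size, dtype=torch.bool, device=x.device), diagonal=1)
        if padding_mask is not None:
            padding_mask = padding_mask.to(x.device).bool()
        return self.blocks(x, mask=causal, src_key_padding_mask=padding_mask)


torch.manual_seed(0)
model = CausalStack(d_model=16, nhead=2, num_layers=2, dim_feedforward=32, dropout=0.0).eval()
x = torch.randn(2, 5, 16)
padding_mask = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 1, 1]])
with torch.no_grad():
    y = model(x, padding_mask)
```

In the model above, `self.transformer` could be replaced by such a stack and called with only `src` and the padding mask; the output heads stay the same.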