Transformer Mask Doesn't Do Anything

I’m trying to train a Transformer Seq2Seq model using the nn.Transformer class. I believe I am implementing it wrong, since when I train it, it seems to fit too fast, and during inference it often repeats itself. This looks like a masking issue in the decoder, and when I remove the target mask, the training performance is the same. This leads me to believe I am doing the target masking wrong. Here is my model code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, input_dim, heads, feedforward_dim, encoder_layers, decoder_layers, sos_token, eos_token, pad_token, max_len=200, dropout=0.5, device=(torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))):
        super(TransformerModel, self).__init__()
        self.target_mask = None
        self.embedding = nn.Embedding(vocab_size, input_dim, padding_idx=pad_token)
        self.pos_embedding = nn.Embedding(max_len, input_dim, padding_idx=pad_token)
        self.transformer = nn.Transformer(d_model=input_dim, nhead=heads, num_encoder_layers=encoder_layers, num_decoder_layers=decoder_layers, dim_feedforward=feedforward_dim, dropout=dropout)
        self.out = nn.Sequential(nn.Linear(input_dim, feedforward_dim), nn.ReLU(), nn.Linear(feedforward_dim, vocab_size))

        self.device = device
        self.max_len = max_len
        self.sos_token = sos_token
        self.eos_token = eos_token


    def init_weights(self): # Initialize all weights to be uniformly distributed between -initrange and initrange
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def generate_square_subsequent_mask(self, size): # Generate mask covering the top right triangle of a matrix
        mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, tgt):
        # src: (Max source seq len, batch size, 1)
        # tgt: (Max target seq len, batch size, 1)

        # Embed source and target with normal and positional embeddings
        embedded_src = self.embedding(src) + self.pos_embedding(torch.arange(0, src.shape[1]).to(self.device).unsqueeze(0).repeat(src.shape[0], 1))
        # Generate target mask
        target_mask = self.generate_square_subsequent_mask(size=tgt.shape[0]).to(self.device) # Create target mask
        embedded_tgt = self.embedding(tgt) + self.pos_embedding(torch.arange(0, tgt.shape[1]).to(self.device).unsqueeze(0).repeat(tgt.shape[0], 1))
        # Feed through model
        outputs = self.transformer(src=embedded_src, tgt=embedded_tgt, tgt_mask=target_mask)
        outputs = F.log_softmax(self.out(outputs), dim=-1)
        return outputs

Are you saying that the model is giving you repeated tokens after training?

I figured out the problem: I was not properly inserting SOS and EOS tokens, so even with proper masking the model was able to copy straight from the given target.


Is there any chance I could see your code? I am facing the same issue as well… :( My model seems to repeat the previous word over and over again: [ <SOS>, CAR, CAR, CAR, ... ]. The masking of tgt doesn’t seem to be doing anything at all.

Also, in terms of adding the SOS and EOS, how did you shift the sentence to the left for the teacher-forcing component?

I don’t have the exact code anymore since I rewrote it all, but for the target you feed the model, you should have an SOS token at the beginning and no EOS token at the end. For the target you use in the loss, there should be no SOS token, only an EOS token at the end.
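
For example, a minimal sketch of that shift (not the exact code from this thread; trg is assumed to already hold both the SOS and EOS tokens, and model/criterion are placeholders):

# trg: (trg_len, batch) of token indices, SOS at position 0, EOS at the end
decoder_input = trg[:-1]  # keeps SOS, drops EOS -> fed to the decoder
loss_target = trg[1:]     # drops SOS, keeps EOS -> compared against the output

output = model(src, decoder_input)  # (trg_len - 1, batch, vocab size)
loss = criterion(output.reshape(-1, output.shape[-1]), loss_target.reshape(-1))

This way the token at decoder position i is always one step behind the token the model is asked to predict at position i.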

Here is my full implementation of a transformer that uses word embeddings:

import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingTransformerSeq2Seq(nn.Module):
    '''A Seq2Seq Transformer which embeds inputs and outputs distributions over the output vocab
    Init Inputs:
        input_size (int): The size of embeddings in the network
        input_vocab (vocab): The input vocab
        target_vocab (vocab): The target vocab
        num_heads (int): The number of heads in both the encoder and decoder
        num_encoder_layers (int): The number of layers in the transformer encoder
        num_decoder_layers (int): The number of layers in the transformer decoder
        forward_expansion (int): The factor of expansion in the elementwise feedforward layer
        dropout (float): The amount of dropout
        max_len (int): The max target length used when a target is not provided
        device (torch.device): The device that the network will run on
    Inputs:
        src (Tensor): The input sequence of shape (src length, batch size)
        trg (Tensor) [default=None]: The target sequence of shape (trg length, batch size)
    Returns:
        output (Tensor): The return sequence of shape (target length, batch size, target tokens)'''
    def __init__(self, input_size, input_vocab, target_vocab, num_heads, num_encoder_layers, num_decoder_layers, forward_expansion, dropout=0.1, max_len=50, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        super().__init__()
        self.hyperparameters = locals()
        self.device = device
        self.input_vocab = input_vocab
        self.target_vocab = target_vocab
        self.max_len = max_len
        self.src_embedding = nn.Embedding(input_vocab.num_words, input_size)
        self.src_positional_embedding = nn.Embedding(max_len, input_size)
        self.trg_embedding = nn.Embedding(target_vocab.num_words, input_size)
        self.trg_positional_embedding = nn.Embedding(max_len, input_size)
        self.transformer = nn.Transformer(d_model=input_size, dim_feedforward=input_size * forward_expansion, nhead=num_heads, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(input_size, target_vocab.num_words)

    def create_pad_mask(self, idx_seq, pad_idx):
        # idx_seq shape: (seq len, batch size)
        mask = idx_seq.transpose(0, 1) == pad_idx
        # mask shape: (batch size, seq len) <- PyTorch transformer wants this shape for mask
        return mask

    def forward(self, src, trg=None):
        src_len, batch_size = src.shape
        trg_len = trg.shape[0] if trg is not None else 1

        # Handle target given/autoregressive
        if trg is None:
            assert batch_size == 1, "In autoregressive mode, the batch size must be 1"
            autoregressive = True
            trg = torch.full((1, batch_size), fill_value=self.target_vocab.SOS_token, dtype=torch.long, device=self.device)
            final_out = torch.zeros((self.max_len, batch_size, self.target_vocab.num_words), device=self.device) # To hold the distributions
        else:
            autoregressive = False
            if trg[0][0].item() != self.target_vocab.SOS_token: # Ensure there is an SOS token at the start of the trg, add if there isn't
                trg = torch.cat((torch.full((1, batch_size), fill_value=self.target_vocab.SOS_token, dtype=torch.long, device=self.device), trg), dim=0)
                trg_len += 1
            if trg[-1][0].item() == self.target_vocab.EOS_token: # Ensure there is no EOS token at the end of the trg, remove if there is
                trg = trg[:-1]
                trg_len -= 1
        # Get source pad mask
        src_pad_mask = self.create_pad_mask(src, self.input_vocab.PAD_token)
        # Embed src
        src_positions = torch.arange(0, src_len, device=self.device).unsqueeze(1).expand(src_len, batch_size)
        src_embed = self.dropout(self.src_embedding(src) + self.src_positional_embedding(src_positions))

        for i in range(self.max_len if autoregressive else 1):
            # Get target pad mask
            trg_pad_mask = self.create_pad_mask(trg, self.target_vocab.PAD_token)

            # Embed target
            trg_positions = torch.arange(0, trg_len, device=self.device).unsqueeze(1).expand(trg_len, batch_size)
            trg_embed = self.dropout(self.trg_embedding(trg) + self.trg_positional_embedding(trg_positions))

            # Get target subsequent mask
            trg_subsequent_mask = self.transformer.generate_square_subsequent_mask(trg_len).to(self.device)

            # Forward pass (in training mode this loop body only runs once)
            out = self.transformer(src=src_embed, tgt=trg_embed, src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=trg_pad_mask, tgt_mask=trg_subsequent_mask)
            out = self.fc_out(out)
            if not self.training: out = F.softmax(out, dim=-1) # During evaluation return probabilities instead of raw logits
            # out shape: (trg_len, batch size, target_num_words)

            if autoregressive:
                trg_len += 1
                final_out[i] = out[-1]
                trg = torch.cat((trg, torch.argmax(out[-1], dim=-1).unsqueeze(0)), dim=0) # Append the predicted token as a new (1, batch size) row
                if all([any(trg[:, x] == self.target_vocab.EOS_token) for x in range(batch_size)]): # EOS was outputted in all batches
                    return final_out[:i + 1]
            else: final_out = out

        # out shape: (trg_len, batch size, target_num_words)
        return final_out



Thank you so much for taking the time to share this.

I was just reading a related blog; do you mean what is depicted in this diagram?

I find it strange that the model requires SOS and EOS tokens for whatever is input into the encoder/decoder, especially since those tokens are converted into their embeddings anyway. I figured that since the model seems to ignore the target masking for the current embeddings, why would it treat the presence of SOS & EOS any differently? How does it know to finally start using the target mask? It’s not like there is some explicit mechanism inside the transformer that says “Hey, we’re finally being fed SOS and EOS tokens, so let’s get our act together and start using the target mask!”

Yes, that is a diagram of what looks like an RNN, but the concepts should be the same. I don’t think you need to feed SOS and EOS tokens to the encoder input; I’m only talking about the target.

The decoder uses the target mask, not the encoder; the encoder and the decoder are two separate transformer stacks. The target is fed into the decoder for teacher forcing to help it train faster, but we need to make sure it can’t just copy the given target to the output, so we use a mask to prevent each position from looking at any later tokens. The SOS token offsets the given target from what should be outputted.
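
To make that concrete, here is a small illustration of my own (not code from either post above) showing the subsequent mask that generate_square_subsequent_mask produces:

import torch

size = 4
# 0 on and below the diagonal, -inf strictly above it,
# which is what generate_square_subsequent_mask(size) returns
mask = torch.triu(torch.full((size, size), float('-inf')), diagonal=1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

Row i is a decoder position, and the -inf entries remove its attention to every position after i. Because the decoder input is shifted right by the SOS token, position i only sees SOS plus target tokens up to i - 1, so the token it has to predict is never visible. Without the shift, position i still sees token i, which is exactly the token it is supposed to predict, which is why the mask appeared to do nothing.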


I agree, I don’t think the SOS and EOS tokens are needed in the source sentence either.

Could you say a bit more about your training process, in particular the auto-regressive part?

Also, if you don’t mind, how did your model end up performing?

My model handles the autoregressive part internally in the forward function. During training it doesn’t do any autoregressive decoding. During testing it has to decode autoregressively, but that is handled internally in my implementation, so testing is very similar to training. As far as performance goes, on my tasks I couldn’t get it to work as well as a GRU Seq2Seq.
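
For reference, with the class above usage would look roughly like this (a sketch; model, src, and trg are placeholders built from the vocabs described in the docstring):

model.train()
output = model(src, trg)  # one teacher-forced pass, (trg len, batch, target words) logits

model.eval()
with torch.no_grad():
    output = model(src)   # autoregressive decoding (batch size must be 1), softmaxed distributions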


Would it be necessary to use the target_mask at test time as well?