Transformer model only predicts START or END tokens

I’ve been building and training a Transformer model from scratch for empathetic dialogue generation, and I’m currently stuck on training: the model only ever predicts the START and END tokens in the final output layer, irrespective of the target token fed to the decoder. I’ve gone through the implementation multiple times and have spotted and corrected some issues (mostly in the multi-head attention layer and the tokenization), but still no luck.

I am using F.cross_entropy to compute the loss between the final logits from the transformer, out[:, :-1, :], and the target sequence from my dataset, target[:, 1:]. The shift is necessary since each output position of the transformer corresponds to the next predicted token. I also tried excluding the START and END tokens from the loss (i.e., out[:, :-2, :] and target[:, 1:-1]), but that didn’t help either. The logits and targets are shaped as the PyTorch documentation requires, i.e., (batch_size, classes, sequence_length) and (batch_size, sequence_length) respectively, with the target containing class indices (the padding index is ignored). The training output looks something like this:

Epoch 0:   1%|          | 1/180 [00:05<16:28,  5.53s/it, loss=11, v_num=3, train_loss=11.00]
Epoch 0:   1%|          | 2/180 [00:25<37:55, 12.78s/it, loss=11, v_num=3, train_loss=11.00]
...
Epoch 5:  90%|█████████ | 162/180 [00:58<00:06,  2.77it/s, loss=5.54, v_num=3, train_loss=5.520]
Epoch 5:  90%|█████████ | 162/180 [00:58<00:06,  2.77it/s, loss=5.53, v_num=3, train_loss=5.430]

As seen above, the loss decays to a value between 5 and 6 and then stays constant (even up to the 50th epoch). I printed the probability tensors at each training step by softmax-ing the logits: the highest probabilities are always assigned to the START and END tokens, irrespective of the target token fed into the decoder.
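
For reference, the loss computation described above boils down to something like this (a minimal sketch, not my exact training step; out, target, and pad_idx are placeholders for my actual variables):

import torch.nn.functional as F

# out:    (batch_size, seq_len, vocab_size) logits from the final linear layer
# target: (batch_size, seq_len) token indices, starting with [START] and ending with [END]
logits = out[:, :-1, :].permute(0, 2, 1)    # -> (batch_size, vocab_size, seq_len - 1)
labels = target[:, 1:]                      # -> (batch_size, seq_len - 1), shifted by one
loss = F.cross_entropy(logits, labels, ignore_index=pad_idx)

# Per-position distributions I inspect manually; mass always ends up on [START]/[END]
probs = F.softmax(out, dim=-1)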

To confirm this behaviour, I wrote a script that predicts a response from the trained model (using beam search) given a context sequence, with the first target token set to [START]. No matter what context sequence I feed in or what beam width I use, the next token is always predicted to be [END]. I’m not sure whether this is a tokenization problem or some weights in the model exploding, but I can’t get rid of the behaviour. I even added dropout layers to rule out the latter, still no luck. The issue persists even if I remove the emotion embeddings I add in the decoder.
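
This is essentially what the decoding script does, with beam search replaced by greedy argmax for brevity (model, context, context_ds_state, emotion_label, and the special token ids are placeholders for my actual setup):

import torch

model.eval()
target = torch.tensor([[start_idx]], device=context.device)   # first target token is [START]

with torch.no_grad():
    for _ in range(max_seq_len - 1):
        target_ds_state = torch.ones_like(target)              # placeholder dialogue-state ids
        logits = model(context, target, context_ds_state, target_ds_state, emotion_label)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # always [END] in practice
        target = torch.cat([target, next_token], dim=1)
        if next_token.item() == end_idx:
            break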

Here is the full implementation of the Model for reference:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size: int, heads: int) -> None:
        super().__init__()

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = self.embed_size // self.heads

        assert self.head_dim * self.heads == self.embed_size

        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)

        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(
        self,
        keys: torch.Tensor, 
        values: torch.Tensor, 
        queries: torch.Tensor, 
        mask: torch.Tensor
    ) -> torch.Tensor:

        N = queries.shape[0]
        keys_len, values_len, queries_len = keys.shape[1], values.shape[1], queries.shape[1]

        values = self.values(values).reshape(N, values_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, keys_len, self.heads, self.head_dim)
        queries = self.queries(queries).reshape(N, queries_len, self.heads, self.head_dim)

        scores = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # Apply mask to attention scores if specified
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-1e20"))

        # Normalise with respect to all keys
        attention = F.softmax(scores / (self.embed_size ** 0.5), dim=-1)

        out = torch.einsum("nhqk,nvhd->nqhd", [attention, values])
        out = self.fc_out(out.reshape(N, queries_len, self.embed_size))

        return out


class TransformerBlock(nn.Module):
    def __init__(
        self,
        embed_size: int, 
        heads: int, 
        dropout: float, 
        forward_expansion: int
    ) -> None:

        super().__init__()

        self.attention = MultiHeadAttention(embed_size, heads)

        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.dropout = nn.Dropout(dropout)

        self.ff = nn.Sequential(
            nn.Linear(embed_size, embed_size * forward_expansion),
            nn.ReLU(),
            nn.Linear(embed_size * forward_expansion, embed_size)
        )

    def forward(
        self,
        keys: torch.Tensor, 
        values: torch.Tensor, 
        queries: torch.Tensor, 
        mask: torch.Tensor
    ) -> torch.Tensor:

        attention = self.attention(keys, values, queries, mask)

        contextualised = self.dropout(self.norm1(attention + queries))
        forward = self.ff(contextualised)
        out = self.dropout(self.norm2(forward + contextualised))

        return out

class Encoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        padding_idx: int,
        num_layers: int,
        embed_size: int,
        heads: int,
        dropout: float, 
        forward_expansion: int,
        max_seq_len: int,
        num_of_emo_labels: int
    ) -> None:

        super().__init__()

        self.word_embeddings = nn.Embedding(
            vocab_size + 1, embed_size, padding_idx=padding_idx)
        self.pos_embeddings = nn.Embedding(max_seq_len, embed_size)
        self.ds_embeddings = nn.Embedding(2 + 1, embed_size, padding_idx=0)

        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, heads, dropout, forward_expansion)
             for _ in range(num_layers)]
        )

        self.dropout = nn.Dropout(dropout)
    
    def forward(
        self, 
        context: torch.Tensor, 
        context_ds_state: torch.Tensor,
        mask: torch.Tensor,
        emotion_label: torch.Tensor
    ) -> torch.Tensor:

        N, seq_len = context.shape
        positions = torch.arange(0, seq_len, device=context.device).expand(N, seq_len)

        word_embeddings = self.word_embeddings(context)
        pos_embeddings = self.pos_embeddings(positions)
        ds_embeddings = self.ds_embeddings(context_ds_state)

        out = self.dropout(word_embeddings + pos_embeddings + ds_embeddings)

        for layer in self.layers:
            out = layer(out, out, out, mask)
        
        return out

class DecoderBlock(nn.Module):
    def __init__(
        self,
        embed_size: int,
        heads: int,
        dropout: float,
        forward_expansion: int
    ) -> None:

        super().__init__()

        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size,
            heads, 
            dropout, 
            forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        target_mask: torch.Tensor,
        input_mask: torch.Tensor
    ) -> torch.Tensor:
        
        attention = self.attention(x, x, x, target_mask)
        queries = self.dropout(self.norm(attention + x))
        out = self.transformer_block(keys, values, queries, input_mask)

        return out

class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        padding_idx: int,
        num_layers: int,
        embed_size: int,
        heads: int,
        dropout: float, 
        forward_expansion: int,
        max_seq_len: int,
        num_of_emo_labels: int
    ) -> None:

        super().__init__()

        self.word_embeddings = nn.Embedding(
            vocab_size + 1, embed_size, padding_idx=padding_idx)
        self.pos_embeddings = nn.Embedding(max_seq_len, embed_size)
        self.ds_embeddings = nn.Embedding(2 + 1, embed_size, padding_idx=0)
        self.emotion_embedding = nn.Embedding(num_of_emo_labels, embed_size)

        self.layers = nn.ModuleList(
            [DecoderBlock(embed_size, heads, dropout, forward_expansion)
             for _ in range(num_layers)]
        )

        self.dropout = nn.Dropout(dropout)

        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(
        self,
        target: torch.Tensor,
        target_ds_state: torch.Tensor,
        encoder_out: torch.Tensor,
        target_mask: torch.Tensor,
        input_mask: torch.Tensor,
        emotion_label: torch.Tensor
    ) -> torch.Tensor:

        N, seq_len = target.shape
        positions = torch.arange(0, seq_len, device=target.device).expand(N, seq_len)

        word_embeddings = self.word_embeddings(target)
        pos_embeddings = self.pos_embeddings(positions)
        ds_embeddings = self.ds_embeddings(target_ds_state)

        out = self.dropout(word_embeddings + pos_embeddings + ds_embeddings)
        
        for layer in self.layers:
            out = layer(out, encoder_out, encoder_out, target_mask, input_mask)
        
        emotion_embedding = self.emotion_embedding(
            emotion_label).unsqueeze(1).expand(-1, seq_len, -1)
        
        out = self.fc_out(out + emotion_embedding)

        return out

class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        num_of_emo_labels: int,
        max_seq_len: int,
        padding_idx: int,
        num_layers: int = 6,
        embed_size: int = 256,
        heads: int = 8,
        dropout: float = 0.5, 
        forward_expansion: int = 4
    ) -> None:

        super().__init__()

        self.padding_idx = padding_idx
        self.encoder = Encoder(
            vocab_size,
            padding_idx,
            num_layers, 
            embed_size, 
            heads,
            dropout, 
            forward_expansion, 
            max_seq_len,
            num_of_emo_labels
        )

        self.decoder = Decoder(
            vocab_size,
            padding_idx,
            num_layers, 
            embed_size, 
            heads,
            dropout, 
            forward_expansion, 
            max_seq_len,
            num_of_emo_labels
        )

    def create_padding_mask(self, batch_seq):
        N = batch_seq.size(dim=0)
        padding_mask = (batch_seq != self.padding_idx).unsqueeze(1).unsqueeze(2)
        return padding_mask
    
    def create_lookahead_mask(self, batch_seq):
        N, seq_len = batch_seq.shape
        lookahead_mask = torch.tril(torch.ones(
            N, 1, seq_len, seq_len, device=batch_seq.device))
        return lookahead_mask
    
    def forward(
        self,
        context: torch.Tensor,
        target: torch.Tensor,
        context_ds_state: torch.Tensor,
        target_ds_state: torch.Tensor,
        emotion_label: torch.Tensor
    ) -> torch.Tensor:

        input_mask = self.create_padding_mask(context)
        target_mask = torch.minimum(
            self.create_lookahead_mask(target), 
            self.create_padding_mask(target)
        )

        encoder_out = self.encoder(
            context, 
            context_ds_state, 
            input_mask,
            emotion_label
        )
        out = self.decoder(
            target, 
            target_ds_state,
            encoder_out, 
            target_mask, 
            input_mask, 
            emotion_label
        )

        return out
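
For completeness, this is roughly how the model is constructed and called in a training step (all names and values below are placeholders rather than my exact configuration):

# Rough sketch of how the model is constructed and called; values are placeholders.
model = Transformer(
    vocab_size=vocab_size,
    num_of_emo_labels=num_of_emo_labels,
    max_seq_len=max_seq_len,
    padding_idx=pad_idx
)

# context, target:                   (batch_size, seq_len) token indices
# context_ds_state, target_ds_state: (batch_size, seq_len) dialogue-state indices (0 = padding)
# emotion_label:                     (batch_size,) emotion class indices
out = model(context, target, context_ds_state, target_ds_state, emotion_label)  # (batch_size, seq_len, vocab_size)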

In case it’s relevant, I have tried both Adam and AdamW as the optimizer, together with a StepLR scheduler (set up roughly as sketched below). I’ve been stuck on this problem for a while now, so any help would be appreciated. Thanks in advance :slight_smile:
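
# Rough sketch of the optimizer/scheduler setup; the learning rate, step size,
# and gamma here are placeholders, not my exact values.
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)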