Transformer translation task: interaction between norm bias and attention bias causing divergence

Hello.

I’m training a transformer to replicate “Attention Is All You Need” for German-to-English translation, and I’ve found that training starts to diverge pretty quickly. I wanted to start with a baseline, so I was following the PyTorch transformer tutorial, i.e. the architecture is a 2-layer encoder, and the “decoder” is just a linear layer followed by a softmax.

Here is an abridged version of the architecture:

import math

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import AutoModelForSeq2SeqLM

# `device` and `config` are defined elsewhere in the full script (abridged out here).


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model, device=device)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim] (batch_first)
        """
        # Logic above assumes [seq_len, batch_size, embedding_dim]
        # Swap axes for batch_size first, then swap back. 
        x = torch.swapaxes(x, 1, 0)
        x = x + self.pe[:x.size(0)]
        x = torch.swapaxes(x, 0, 1)        
        return self.dropout(x)  
    
    
class BaselineTransformer(nn.Module): 
    def __init__(self, 
                 model_dim=512, 
                 nhead=8, 
                 num_encoder_layers=6,
                 num_decoder_layers=6, 
                 output_vocab_size=10000, 
                 padding_idx=0, 
                 dropout=0.0): 
        super().__init__()
        self.model_dimension = model_dim
        self.nhead = nhead
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.output_vocab_size = output_vocab_size
        
        # Reuse the token-embedding layer from the pretrained de->en model's encoder
        # (note: init_weights() below re-initializes these weights uniformly).
        pretrained_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-de-en")
        
        self.encoder_input_embeddings = pretrained_model.get_encoder().embed_tokens.to(device)
        
        self.pos_encoder = PositionalEncoding(model_dim, dropout=dropout)
        encoder_layers = TransformerEncoderLayer(d_model=model_dim, nhead=nhead, batch_first=True, dropout=dropout, norm_first=config['norm_first'])
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_encoder_layers).to(device)
    
        self.target_lm_head = nn.Linear(self.model_dimension, self.output_vocab_size, device=device)
        self.log_softmax = torch.nn.LogSoftmax(dim=2)   
        self.init_weights()
        
    def init_weights(self):
        initrange = 0.1
        self.encoder_input_embeddings.weight.data.uniform_(-initrange, initrange)
        self.target_lm_head.weight.data.uniform_(-initrange, initrange)
        # Center the output bias on log(1 / vocab_size) so the head initially
        # predicts a roughly uniform distribution over the target vocabulary.
        self.target_lm_head.bias.data.uniform_(math.log(1 / self.output_vocab_size) - initrange,
                                               math.log(1 / self.output_vocab_size) + initrange)
        
    def forward(self, source, target, tokenizer,
                src_key_padding_mask=None, tgt_key_padding_mask=None, tgt_mask=None, src_mask=None):
        # Note: target, tokenizer, tgt_mask and tgt_key_padding_mask are currently
        # unused, since this baseline has no transformer decoder.
        source_embeddings_raw = self.encoder_input_embeddings(source) * math.sqrt(self.model_dimension)
        source_embeddings = self.pos_encoder(source_embeddings_raw)        
        memory = self.transformer_encoder(
            source_embeddings, 
            src_key_padding_mask=src_key_padding_mask,
            mask=src_mask
        ) 
        
        logits = self.target_lm_head(memory)
        res = self.log_softmax(logits)   
        return res

It looks like the training loss reaches some low level (≈ 7.5 NLL) and then starts to diverge rapidly. It also appears as if there is some interaction between the bias of the layer norm and the bias of the attention operation.

Here is the loss curve: [image]

Here are some charts plotting the parameter values evolving over training: [image]

And here are some annotations of those charts that suggest an interaction between the norm bias and the attention bias (though I’m not positive about this): [image]

I’m also using 1/5 of the learning rate from the “Attention Is All You Need” paper, so I’m a bit surprised by the training divergence.
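
(For reference, the schedule from the paper is lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5) with warmup_steps = 4000. Here is a minimal sketch of that schedule via LambdaLR; the Linear module is just a placeholder standing in for the real model.)

import torch

d_model, warmup_steps = 512, 4000

def noam_lr(step: int) -> float:
    # Learning-rate factor from Section 5.3 of "Attention Is All You Need".
    step = max(step, 1)  # avoid division by zero on the first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

placeholder_model = torch.nn.Linear(d_model, d_model)  # stand-in for the transformer
optimizer = torch.optim.Adam(placeholder_model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lr)
# scheduler.step() is then called once per optimizer step, not once per epoch.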

Finally, right around when the divergence starts, the attention weights of the first encoder layer collapse, and it seems to start attending to the same tokens for every query token (note that the rows sum to 1 in this heatmap):



Any ideas on what might be happening? This model doesn’t perform much better than an input-independent baseline (which reaches 8.5 NLL, whereas this model hits around 7 before diverging).
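
(In case it’s useful context for the heatmap above: here is a standalone sketch of how a row-stochastic attention matrix can be pulled out of nn.MultiheadAttention with need_weights=True. This isn’t my exact logging code, just an illustration of why each row sums to 1.)

import torch

mha = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(1, 10, 512)                        # (batch, seq_len, d_model)
_, attn_weights = mha(x, x, x, need_weights=True)  # weights: (batch, tgt_len, src_len), averaged over heads
print(attn_weights[0].sum(dim=-1))                 # each row is a softmax, so it sums to 1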

Also, here are the hyperparameters used for training, in case that helps.

@mrich This can happen with any optimizer when training with an NLL-based loss.

  • You can introduce a smart learning-rate reduction when you find that the loss has not decreased for 2 to 3 epochs; otherwise the optimizer can pick up an incorrect gradient direction and keep moving that way (see the sketch after this list)

  • You can reduce the batch_size to 16 or 32, as this lets the model train in a more generic direction rather than a specific one after every epoch

  • You can switch over from adaptive optimizers like Adam to SGD after the loss has dropped to a certain level, so that the optimizer manoeuvres the gradient drops more efficiently (also covered in the sketch below)
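
Here is a rough sketch of what the first and third suggestions could look like in PyTorch. Everything here (the Linear model, the dummy data, the loss threshold, the hyperparameter values) is a placeholder, not your actual setup:

import torch

model = torch.nn.Linear(512, 10000)                       # placeholder for your transformer
criterion = torch.nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Suggestion 1: shrink the learning rate when the loss has not improved for ~2 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

switched_to_sgd = False
for epoch in range(20):
    # Dummy batch standing in for a full training epoch.
    x = torch.randn(32, 512)
    y = torch.randint(0, 10000, (32,))
    optimizer.zero_grad()
    loss = criterion(torch.log_softmax(model(x), dim=-1), y)
    loss.backward()
    optimizer.step()

    scheduler.step(loss.item())  # ReduceLROnPlateau monitors the reported loss

    # Suggestion 3: once the loss is low enough, hand off from Adam to plain SGD.
    if not switched_to_sgd and loss.item() < 5.0:         # threshold is illustrative
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)
        switched_to_sgd = True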

This exercise is more of an art than a science, so you will only get tips. Good luck!


Thank you @anantguptadbl!!

Just wanted to follow up here: I’ve only tried your second suggestion so far (changing the batch size to 32), but it seems to lead to much better results. I really appreciate your help; I was stuck on this for a little while! I probably should have tried some sort of hyperparameter sweep earlier. Anyway, I’ve only run this for a few hours, so I’m sure the results will improve further.

For anyone looking at this in the future:

Losses / Validation Losses: [image]

Histograms of intermediate values (memory = encoder output, predicted embedding = decoder output): [image]

Last Attention Layers: [images]

I’ll follow up with a more detailed write-up via a blog post so that others can replicate this later. Appreciate your help again.
