Wrote a transformer from scratch but the loss isn't decreasing

Hi there,

So I followed this tutorial to implement the transformer architecture from the “Attention Is All You Need” paper, and I had to change its code a bit since it had some mistakes. I'm using the model for a neural machine translation task, but my loss isn't decreasing; it always stays in the range of 5 to 5.7. My input and target tensors have the shape (batch_size, seq_len). Here is an example of a toy dataset:

src = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]])
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]])

where 1 & 2 are the <sos> & <eos> tokens respectively and 0 is the <pad> token.
Here is my code for the transformer:

class AttentionHead(nn.Module):
    def __init__(self, emb_dim, dim_kqv):
        super(AttentionHead, self).__init__()
        self.dim_kqv = dim_kqv
        
        self.wq = nn.Linear(emb_dim, dim_kqv)
        self.wk = nn.Linear(emb_dim, dim_kqv)        
        self.wv = nn.Linear(emb_dim, dim_kqv)
        
    def forward(self, q, k, v, mask):
        queries = self.wq(q)
        keys = self.wk(k)
        values = self.wv(v)
        
        score = queries.bmm(keys.transpose(1, 2))     
          
        # scale by sqrt(d_k); plain float division (rounding_mode='floor' would truncate the scores)
        score = score / (self.dim_kqv ** 0.5)
        
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        
        softmax = F.softmax(score, dim = -1)

        return softmax.bmm(values)
    
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, emb_dim, dim_kqv):
        super(MultiHeadAttention, self).__init__()
        self.heads = nn.ModuleList(
            [AttentionHead(emb_dim, dim_kqv) for _ in range(num_heads)]
        )
        
        self.w0 = nn.Linear(num_heads * dim_kqv, emb_dim)
        
    def forward(self, q, k, v, mask):
        attentions = [h(q, k, v, mask) for h in self.heads]
        attentions = torch.cat(attentions, dim = -1)
        out = self.w0(attentions)
        
        return out

class Residual(nn.Module):
    def __init__(self, sublayer, dimension, dropout):
        super(Residual, self).__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dimension)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *tensors):
        return self.dropout(self.norm(tensors[0] + self.sublayer(*tensors)))
    
class FeedForward(nn.Module):
    def __init__(self, emb_dim, ff_dim):
        super(FeedForward, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(emb_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, emb_dim)
        )
        
    def forward(self, residual_out):
        return self.network(residual_out)
    

class EncoderLayer(nn.Module):
    def __init__(self, emb_dim, num_heads, ff_dim, dropout):
        super(EncoderLayer, self).__init__()
        self.dim_kqv = emb_dim // num_heads
        
        assert (self.dim_kqv * num_heads == emb_dim), "Embedding size must be divisible by number of heads" 
        
        self.attention = Residual(
            MultiHeadAttention(num_heads, emb_dim, self.dim_kqv),
            dimension=emb_dim,
            dropout=dropout,
        )
    
        self.feed_forward = Residual(
            FeedForward(emb_dim, ff_dim),
            dimension=emb_dim,
            dropout=dropout,
        )
        
    def forward(self, src, mask):
        src = self.attention(src, src, src, mask)
        out = self.feed_forward(src)
        return out
    
class Encoder(nn.Module):
    def __init__(self, 
                 emb_dim, 
                 num_heads, 
                 ff_dim, 
                 num_layers, 
                 src_vocab_size,
                 padding_index,
                 dropout):
        
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(emb_dim,
                         num_heads,
                         ff_dim,
                         dropout)
            for _ in range(num_layers)
        ])
        
        self.embedding = nn.Embedding(src_vocab_size, emb_dim, padding_idx=padding_index)
        self.pe = PositionalEncoder(emb_dim)
        
    def forward(self, src):
        src = self.embedding(src)
        
        src = self.pe(src)
        
        for layer in self.layers:
            src = layer(src, None)
            
        return src
    
class DecoderLayer(nn.Module):
    def __init__(self, emb_dim, num_heads, ff_dim, dropout):
        super(DecoderLayer, self).__init__()
        
        self.dim_kqv = emb_dim // num_heads
        
        assert (self.dim_kqv * num_heads == emb_dim), "Embedding size must be divisible by number of heads"
        
        self.attention_1 = Residual(
            MultiHeadAttention(num_heads, emb_dim, self.dim_kqv),
            dimension=emb_dim,
            dropout=dropout
        )
        
        self.attention_2 = Residual(
            MultiHeadAttention(num_heads, emb_dim, self.dim_kqv),
            dimension=emb_dim,
            dropout=dropout
        )
        
        self.feed_forward = Residual(
            FeedForward(emb_dim, ff_dim),
            dimension=emb_dim,
            dropout=dropout
        )
        
    def forward(self, trg, memory, mask):
        query = self.attention_1(trg, trg, trg, mask)
        attentions = self.attention_2(query, memory, memory, None)
        out = self.feed_forward(attentions)
        
        return out

class Decoder(nn.Module):
    def __init__(self, 
                 emb_dim, 
                 num_heads, 
                 ff_dim, 
                 num_layers, 
                 out_size, 
                 padding_index,
                 dropout):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(emb_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        
        self.embedding = nn.Embedding(out_size, emb_dim, padding_idx=padding_index)
        self.pe = PositionalEncoder(emb_dim)

    def make_trg_mask(self, trg):
        batch_size, seq_len = trg.shape[0], trg.shape[1]
        mask = torch.tril(torch.ones(batch_size, seq_len, seq_len))
        return mask
        
    def forward(self, trg, encoder_out):
        trg = self.embedding(trg)
        
        trg = self.pe(trg)

        mask = self.make_trg_mask(trg).to(trg.device)  # .get_device() fails for CPU tensors
        
        for layer in self.layers:
            trg = layer(trg, encoder_out, mask)
            
        # return self.lin(trg)
        return trg

class VanillaTransformer(nn.Module):
    def __init__(self, 
                 emb_dim, 
                 num_heads, 
                 ff_dim, 
                 num_layers, 
                 src_vocab_size, 
                 trg_vocab_size,
                 device,
                 padding_index,
                 dropout):
        super(VanillaTransformer, self).__init__()
        
        self.encoder = Encoder(emb_dim, 
                               num_heads, 
                               ff_dim,
                               num_layers, 
                               src_vocab_size,
                               padding_index,
                               dropout).to(device)
        
        self.decoder = Decoder(emb_dim,
                               num_heads,
                               ff_dim, 
                               num_layers,
                               trg_vocab_size,
                               padding_index,
                               dropout).to(device)

        self.lin = nn.Linear(emb_dim, trg_vocab_size)
        
        
    def forward(self, src, trg):
        encoder_out = self.encoder(src)

        decoder_out = self.decoder(trg, encoder_out)
        
        out = self.lin(decoder_out)

        return out
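
One thing not shown above is PositionalEncoder, which both the Encoder and the Decoder use. It's the standard sinusoidal positional encoding from the paper; a minimal sketch of that encoding (not my exact code) looks like this:

class PositionalEncoder(nn.Module):
    def __init__(self, emb_dim, max_len=5000):
        super(PositionalEncoder, self).__init__()
        # precompute a (1, max_len, emb_dim) table of sin/cos position encodings
        position = torch.arange(max_len).unsqueeze(1).float()
        denom = 10000.0 ** (torch.arange(0, emb_dim, 2).float() / emb_dim)
        pe = torch.zeros(max_len, emb_dim)
        pe[:, 0::2] = torch.sin(position / denom)
        pe[:, 1::2] = torch.cos(position / denom)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x: (batch_size, seq_len, emb_dim); add the encodings for the first seq_len positions
        return x + self.pe[:, :x.size(1)]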

And here is the training loop:

model = VanillaTransformer(emb_dim=256,
                           num_heads=8,
                           ff_dim=2048,
                           num_layers=6,
                           src_vocab_size=10,
                           trg_vocab_size=10,
                           device=device,
                           padding_index=0,
                           dropout=0.1).to(device)


optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_func = nn.CrossEntropyLoss(ignore_index=padding_index)

for epoch in tqdm(range(NUM_EPOCHS), position = 0, leave = True):
    
    running_loss = 0.0
    model.train()
    
    for batch_index, batch in enumerate(train_loader):
        optimizer.zero_grad()
        
        src = batch['x_source']
        trg = batch['y_target']
        
        y_pred = model(src.to(device), trg[:, :-1].to(device))
        y_pred = y_pred.reshape(-1, y_pred.shape[2])
        
        loss = loss_func(y_pred, trg[:, 1:].reshape(-1).to(device))
        loss.backward()

        running_loss += (loss.item() - running_loss) / (batch_index + 1)
        optimizer.step()
        
    if epoch == 0 or (epoch + 1) % PRINT_EVERY == 0:
      print('Epoch: {:<2} Train loss: {:0.4f}'.format(epoch + 1 , running_loss))

I understand the amount of code here could be overwhelming, so I would suggest first reading the article linked above; it's a great read, everything there is clearly explained, and it will make my code much easier to follow. I would also appreciate it if you could let me know whether the way I'm feeding input to the decoder, and how its output is being compared to the targets, is correct.
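
To make that concrete, here is the slicing I mean, using the toy trg from above (just an illustration of the teacher-forcing shift from my training loop):

# trg rows start with 1 = <sos>, end with 2 = <eos>, and use 0 = <pad>
decoder_input = trg[:, :-1]    # drop the last token; still starts with <sos>
labels        = trg[:, 1:]     # drop <sos>; the token the decoder should predict at each step

y_pred = model(src, decoder_input)                     # (batch_size, seq_len - 1, trg_vocab_size)
loss = loss_func(y_pred.reshape(-1, y_pred.shape[2]),  # flatten to (batch_size * (seq_len - 1), trg_vocab_size)
                 labels.reshape(-1))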

Been at this for 2 days now, so any help is appreciated. Thanks 🙂

OK, so I figured it out: my learning rate was way too high, and setting it to something like 3e-4 fixed the issue. I was also able to verify that the way I'm sending the input to the decoder is correct.
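
In other words, the only change needed was the learning rate passed to the optimizer:

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)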