nn.Transformer is not working on Multi30k dataset

I have implemented a simple model based on nn.Transformer and trained it on the Multi30k dataset. The final model always predicts the same output regardless of the input. Can you please suggest what is wrong?

https://colab.research.google.com/drive/1KssnWND4dKx_aBTW0TQZlI4THttHiFTm?usp=sharing

import math
import torch
import torch.nn as nn

class TransformerBase(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=256):
        super(TransformerBase, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        self.transformer_model = nn.Transformer(d_model=hidden_size)
        self.embedding_input = nn.Embedding(self.input_size, hidden_size)
        self.embedding_output = nn.Embedding(self.output_size, hidden_size)
        # PositionalEncoding is the sinusoidal encoding from the PyTorch transformer tutorial
        self.pos_encoder = PositionalEncoding(hidden_size, 0.1)
        self.fc_out = nn.Linear(hidden_size, self.output_size)

    def forward(self, src, trg):
        # scale embeddings by sqrt(d_model) before adding positional encodings
        embedded_input = self.pos_encoder(self.embedding_input(src) * math.sqrt(self.hidden_size))
        embedded_output = self.pos_encoder(self.embedding_output(trg) * math.sqrt(self.hidden_size))

        # causal mask so the decoder cannot attend to future target positions
        tgt_mask = self.transformer_model.generate_square_subsequent_mask(trg.shape[0]).to(src.device)
        x = self.transformer_model(src=embedded_input, tgt=embedded_output, tgt_mask=tgt_mask)

        return self.fc_out(x)
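
For reference, one common cause of exactly this symptom on Multi30k (near-constant predictions, very slow convergence) is that no padding masks are passed to nn.Transformer, so attention also attends to pad tokens. A minimal sketch of a forward that supplies them, assuming batches shaped (S, N) / (T, N) and a hypothetical PAD_IDX for your pad token id:

def forward(self, src, trg):
    # True where a position is padding; key_padding_masks are (N, S) / (N, T),
    # while src/trg here are (S, N) / (T, N), hence the transpose
    src_pad_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_pad_mask = (trg == PAD_IDX).transpose(0, 1)

    embedded_input = self.pos_encoder(self.embedding_input(src) * math.sqrt(self.hidden_size))
    embedded_output = self.pos_encoder(self.embedding_output(trg) * math.sqrt(self.hidden_size))

    tgt_mask = self.transformer_model.generate_square_subsequent_mask(trg.shape[0]).to(src.device)
    x = self.transformer_model(
        src=embedded_input,
        tgt=embedded_output,
        tgt_mask=tgt_mask,
        src_key_padding_mask=src_pad_mask,
        tgt_key_padding_mask=tgt_pad_mask,
        memory_key_padding_mask=src_pad_mask,
    )
    return self.fc_out(x)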

I am currently experiencing this issue as well! I have been doing some analysis on the decoder side, and I am pretty sure that nn.TransformerDecoder is the root cause. Do you want to try the following example?

# https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html
import math
import torch
import torch.nn as nn

# standard sinusoidal positional encoding from the PyTorch transformer tutorial
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

pos_encoder = PositionalEncoding(d_model=42, dropout=0.0)
torch.manual_seed(0)
memory = torch.rand(4, 1, 42)            # src: (S, N, E), see https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
tgt = pos_encoder(torch.rand(4, 1, 42))  # tgt: (T, N, E)
tgtmask = torch.nn.Transformer().generate_square_subsequent_mask(4).float()

decoder_layer = nn.TransformerDecoderLayer(d_model=42, nhead=7, dropout=0.0)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(transformer_decoder.parameters(), lr=0.05)

num_epochs = 1000

transformer_decoder.train()

for epoch in range(num_epochs):
    out = transformer_decoder(tgt, memory, tgt_mask=tgtmask)
    loss = criterion(out.permute(1, 0, 2), tgt.permute(1, 0, 2))
    print("predicted", out)
    print("target", tgt)

    # zero out the old parameter gradients, then backpropagate
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f'Epoch {epoch + 1} | Loss for this epoch: {loss.item():.4f}')
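
If you want to rule the mask itself out: generate_square_subsequent_mask already returns an additive float mask with -inf above the diagonal (so the extra .float() call is a no-op), which you can verify directly:

print(torch.nn.Transformer().generate_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])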

@odats

I gave my code more time and it started to learn something. But it trains much slower, and the final loss is much higher than it should be (compared with other implementations with a similar number of parameters, architecture, and learning rate).
@mathematicsofpaul


Do you want to upload your model to Colab so that I could take a look? What did you end up changing?

I still use the same model, and after some time it starts to produce better predictions. But it takes more time than I was expecting, and the final result is worse than other implementations.
My implementation.
A good, working implementation. It converges much faster, and the final loss is N times better.

Can you please suggest what might be wrong with my code and how I should debug the issue?
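
For reference, the standard teacher-forcing step I am comparing against looks roughly like this (a sketch, not my exact Colab code; PAD_IDX stands for the pad token id):

# decoder input is the target shifted right; loss target is shifted left
trg_input = trg[:-1, :]        # <bos> w1 w2 ... w_{T-1}
trg_expected = trg[1:, :]      # w1 w2 ... w_{T-1} <eos>

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

logits = model(src, trg_input)                          # (T-1, N, output_size)
loss = criterion(logits.reshape(-1, logits.shape[-1]),  # flatten time and batch dims
                 trg_expected.reshape(-1))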

@mathematicsofpaul
