Transformer is not learning on a sequence-to-sequence task

I have implemented a Transformer model based on nn.Transformer.
Unfortunately, my model is not learning. I have LSTM-based networks that learn well on the same sequence-to-sequence dataset. Can you please suggest what might be wrong?

import math
import torch
import torch.nn as nn

# `device` is defined elsewhere in the script, e.g. device = torch.device('cuda')

class TransformerBase(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=256):
        super(TransformerBase, self).__init__()
        
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
              
        self.transformer_model = nn.Transformer(d_model=hidden_size)
        self.embedding_input = nn.Embedding(self.input_size, hidden_size)
        self.embedding_output = nn.Embedding(self.output_size, hidden_size)
        self.pos_encoder = PositionalEncoding(hidden_size, 0.1)
        self.fc_out = nn.Linear(hidden_size, self.output_size)
        
    def forward(self, src, trg):
        # src, trg: (seq_len, batch) token indices; nn.Transformer expects sequence-first inputs
        embedded_input = self.pos_encoder(self.embedding_input(src) * math.sqrt(self.hidden_size))
        embedded_target = self.pos_encoder(self.embedding_output(trg) * math.sqrt(self.hidden_size))
        
        # causal mask so the decoder cannot attend to future target positions
        tgt_mask = generate_square_subsequent_mask(trg.shape[0])
        x = self.transformer_model(src=embedded_input, tgt=embedded_target, tgt_mask=tgt_mask.to(device))
        
        return self.fc_out(x)

def generate_square_subsequent_mask(sz: int):
    # Additive causal mask: 0.0 on and below the diagonal, -inf above it
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask
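For reference, this produces the additive mask format that nn.Transformer expects for tgt_mask, e.g. for a target length of 4:

generate_square_subsequent_mask(4)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])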

# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
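(The div_term above is just 1 / 10000^(2i / d_model) computed via exp/log, so the pe buffer holds the sinusoidal encoding from the paper: pe[pos, 2i] = sin(pos / 10000^(2i / d_model)) and pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model)).)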

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        # Move inputs and targets to the device
        src = batch.src.to(device)
        trg = batch.trg.to(device)
        
        optimizer.zero_grad()
        
        output = model(src, trg[:-1])  # teacher forcing: decoder input is the target without its last token
        
        output = output.reshape(-1, output.shape[-1]).contiguous()
        trg = trg[1:].reshape(-1).contiguous()  # labels are the target without the leading SOS token
        
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

@odats Can you explain to me what exactly is happening? Are you getting repeated tokens during training?

The model converges to almost the same prediction for most inputs (I think the small variations are due to dropout). The loss slowly decreases to a stable but high value. LSTM-based models achieve an N times better score.

Do you think the problem might be in the model definition or training process?
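A quick way to see the collapse during training (a minimal sketch, reusing model, iterator, and device from the code above; the batch layout is assumed to match the train() loop):

model.eval()
with torch.no_grad():
    batch = next(iter(iterator))
    src, trg = batch.src.to(device), batch.trg.to(device)
    preds = model(src, trg[:-1]).argmax(-1)            # (tgt_len - 1, batch) greedy predictions
    tokens, counts = torch.unique(preds, return_counts=True)
    print(tokens, counts)                              # a collapsed model uses only a handful of tokens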

Can you post some examples of your prediction outputs? From the training phase is fine.

training:

tensor([[  2.4237,   2.5905,   2.2169, -10.5635,  -1.0118, -10.1921],
        [  3.0716,   3.8252,   3.1995,  -8.1787,  -4.2393,  -8.7988],
        [  2.8787,   3.6345,   2.8708,  -8.3596,  -4.4989,  -8.4322],
        ...,
        [  1.5031,   1.0884,   1.2739,  -9.6650,   2.4617,  -9.0367],
        [  1.3291,   1.5163,   1.4576, -10.0381,   2.1117,  -9.5944],
        [  0.9889,   1.2590,   1.0715,  -9.5567,   2.5086,  -9.4345]],
       grad_fn=<ViewBackward>)
tensor([[  3.1798,   3.6515,   2.7826,  -8.4505,  -4.2253,  -8.5992],
        [  3.2187,   3.5649,   2.8608,  -8.9943,  -3.8257,  -8.1756],
        [  2.8734,   3.4187,   2.6570,  -7.9880,  -4.0824,  -8.3490],
        ...,
        [  1.0231,   1.5060,   0.9266,  -9.2981,   2.0333,  -9.3198],
        [  1.3715,   1.2931,   1.0752,  -9.4257,   2.2230,  -9.2558],
        [  1.7830,   1.6602,   1.5629, -10.3801,   1.7262, -10.1634]],
       grad_fn=<ViewBackward>)
tensor([[  3.0281,   3.7362,   3.0541,  -8.4012,  -4.0972,  -8.8399],
        [  2.5452,   3.5018,   2.7213,  -7.9301,  -3.7148,  -8.4083],
        [  2.9479,   3.6801,   2.8139,  -8.8359,  -3.9234,  -8.8597],

prediction loop:

SOS=3, EOS=4, PAD=5

def show_predictions_transformer(model, loader):
    model.eval()
    
    with torch.no_grad():
        for source, target in loader:

            # greedy decoding: start from SOS and feed the generated prefix back in
            target_max_len = 12
            outputs = torch.zeros(target_max_len, dtype=torch.long)
            outputs[0] = torch.LongTensor([SOS])

            for i in range(1, target_max_len):
                pred = model(source, outputs[:i].unsqueeze(1))  # prefix shaped (i, 1) for batch size 1
                outputs[i] = pred[-1].argmax(1)                 # keep the prediction for the last position
                if outputs[i] == EOS:
                    break

            print('targ', target.squeeze())
            print('pred', outputs)
targ tensor([3, 1, 1, 2, 1, 0, 2, 4])
pred tensor([3, 1, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0])
Epoch: 04 | Time: 0m 27s
	Train Loss: 1.180 | Train PPL:   3.254
	 Val. Loss: 1.104 |  Val. PPL:   3.017
targ tensor([3, 2, 2, 2, 0, 1, 1, 4])
pred tensor([3, 1, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0])
Epoch: 05 | Time: 0m 27s
	Train Loss: 1.176 | Train PPL:   3.240
	 Val. Loss: 1.131 |  Val. PPL:   3.098
targ tensor([3, 2, 1, 1, 2, 1, 0, 4])
pred tensor([3, 1, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0])
Epoch: 06 | Time: 0m 26s
	Train Loss: 1.166 | Train PPL:   3.210

Your mask might be off? I’m not entirely sure, but it seems that you are passing trg[:-1], i.e. all but the last value of the target.
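For what it's worth, feeding trg[:-1] to the decoder and comparing against trg[1:] is the usual teacher-forcing setup; here is a minimal sketch of how the pieces are meant to line up (token ids follow SOS=3, EOS=4 from above, and generate_square_subsequent_mask is the function posted earlier):

# toy target: SOS a b c EOS, shaped (tgt_len, batch=1)
trg = torch.tensor([[3], [1], [2], [0], [4]])

decoder_input = trg[:-1]   # SOS a b c   -> what the decoder sees
labels        = trg[1:]    # a b c EOS   -> what the loss compares against
tgt_mask = generate_square_subsequent_mask(decoder_input.shape[0])  # 4x4 causal mask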

It’s an old thread, but I think I recently found something related to this issue.

When I looked into my site-packages/torch/nn/functional.py, which I assume is where the MultiheadAttention implementation lives, I found that the attention scores are not divided by the square root of the attention dimension, as suggested in the Transformer paper. As the paper's motivation for this division explains, skipping it may saturate the softmax function.
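For reference, the scaling the paper describes is the 1/sqrt(d_k) factor in scaled dot-product attention; a minimal sketch of the formula (not the actual torch internals):

import math
import torch

def scaled_dot_product_attention(q, k, v, attn_mask=None):
    # q, k, v: (..., seq_len, d_k); attn_mask: additive mask of 0 / -inf
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # the sqrt(d_k) division in question
    if attn_mask is not None:
        scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ v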