Transformer model doesn't improve even when fed the same single example over and over

I have created a very simple transformer model using PyTorch, but the loss does not decrease during training as expected. To narrow down the cause, I tried feeding a single example to the transformer over and over again. I expected the transformer to quickly overfit; instead, the loss does not decrease at all.

Here is my model and training code. As you can see, it is basically a raw transformer model with linear layers acting as embedding layers.

import math

import torch
import torch.nn as nn

# Note: objprocessor, sourcenode, batch_iterator, and training_files come from my own project code.

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self, nhead=8, dim_feedforward=1024, num_layers=6, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.d_model = 512

        self.input_embedding1 = nn.Linear(objprocessor.MAX_INSTR_SIZE, self.d_model)
        self.output_embedding1 = nn.Linear(sourcenode.NODE_ID_END, self.d_model)

        self.pos_encoder = PositionalEncoding(self.d_model, dropout)

        self.transformer = nn.Transformer(d_model=self.d_model, nhead=nhead, dim_feedforward=dim_feedforward,
                                          num_encoder_layers=num_layers, num_decoder_layers=num_layers, dropout=dropout)

        self.output_linear = nn.Linear(self.d_model, sourcenode.NODE_ID_END)
        self.soft_max = nn.Softmax(dim=2)

    def forward(self, src, tgt):
        # src is a tensor of shape (S, N, objprocessor.MAX_INSTR_SIZE), where N = batch size,
        # S = sequence length of input, objprocessor.MAX_INSTR_SIZE = channel size
        # tgt is a tensor of shape (T, N, sourcenode.NODE_ID_END), where N = batch size,
        # T = sequence length of output, sourcenode.NODE_ID_END = channel size
        tgt_len = tgt.shape[0]

        #tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_len).cuda()
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_len)

        src = self.input_embedding1(src)
        src = self.pos_encoder(src)

        tgt = self.output_embedding1(tgt)
        tgt = self.pos_encoder(tgt)

        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        output = self.output_linear(output)
        output = self.soft_max(output)
        return output

model = TransformerModel(dropout=0.0)
#model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

model.train()
for (src, tgt) in batch_iterator(training_files):
    # Feed the same example to the transformer 1000 times to test if it's learning anything
    for _ in range(1000):
        #src = src.cuda()
        #tgt = tgt.cuda()
        optimizer.zero_grad()
        output = model(src, tgt)
        loss = criterion(output.view(-1, sourcenode.NODE_ID_END), torch.argmax(tgt.view(-1, sourcenode.NODE_ID_END), dim=1))
        print("Loss", loss)
        loss.backward()
        optimizer.step()

Output:

Loss tensor(6.2542, grad_fn=<NllLossBackward>)
Loss tensor(6.1925, grad_fn=<NllLossBackward>)
Loss tensor(6.1591, grad_fn=<NllLossBackward>)
Loss tensor(6.1591, grad_fn=<NllLossBackward>)
Loss tensor(6.1591, grad_fn=<NllLossBackward>)
Loss tensor(6.1590, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1589, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1584, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)
Loss tensor(6.1583, grad_fn=<NllLossBackward>)

When training on actual data, my model exhibits similar behavior: the loss does not decrease, even when the transformer is presented with many different examples.

Update:

I have now provided a standalone example of the problem that shows that the transformer doesn’t learn. I have simplified the model even more:

import torch
import torch.nn as nn
import random
import math

d_model = 512

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self, nhead=8, dim_feedforward=1024, num_layers=6, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.d_model = d_model
        self.output_embedding = nn.Linear(self.d_model, self.d_model)
        self.pos_encoder = PositionalEncoding(self.d_model, dropout)
        self.transformer = nn.Transformer(d_model=self.d_model, nhead=nhead, dim_feedforward=dim_feedforward,
                                          num_encoder_layers=num_layers, num_decoder_layers=num_layers, dropout=dropout)
        self.output_linear = nn.Linear(self.d_model, self.d_model)
        self.soft_max = nn.Softmax(dim=2)

    def forward(self, src, tgt):
        # src is a tensor of shape (S, N, d_model), where N = batch size,
        # S = sequence length of input, d_model = channel size
        # tgt is a tensor of shape (T, N, d_model), where N = batch size,
        # T = sequence length of output, d_model = channel size
        tgt_len = tgt.shape[0]

        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_len)

        src = self.pos_encoder(src)

        tgt = self.output_embedding(tgt)
        tgt = self.pos_encoder(tgt)

        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        output = self.output_linear(output)
        output = self.soft_max(output)
        return output

src_len = 30
tgt_len = 20
batch_size = 1

src = torch.randn((src_len, batch_size, d_model))
tgt = torch.zeros((tgt_len, batch_size, d_model))

for i in range(tgt_len):
    for j in range(batch_size):
        k = random.randrange(0, d_model)
        tgt[i, j, k] = 1.0

model = TransformerModel(dropout=0.0)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Feed the same example to the transformer 1000 times to test if it's learning anything
for _ in range(1000):
    optimizer.zero_grad()
    output = model(src, tgt)
    loss = criterion(output.view(-1, d_model), torch.argmax(tgt.view(-1, d_model), dim=1))
    difference = torch.sum(torch.abs(tgt - output))
    print("Loss", loss)
    print("Difference", difference)
    loss.backward()
    optimizer.step()

Output:

Loss tensor(6.2385, grad_fn=<NllLossBackward>)
Difference tensor(39.9306, grad_fn=<SumBackward0>)
Loss tensor(6.2116, grad_fn=<NllLossBackward>)
Difference tensor(38.8502, grad_fn=<SumBackward0>)
Loss tensor(6.1904, grad_fn=<NllLossBackward>)
Difference tensor(38.0000, grad_fn=<SumBackward0>)
Loss tensor(6.1903, grad_fn=<NllLossBackward>)
Difference tensor(38.0000, grad_fn=<SumBackward0>)
Loss tensor(6.1903, grad_fn=<NllLossBackward>)
Difference tensor(38., grad_fn=<SumBackward0>)
Loss tensor(6.1903, grad_fn=<NllLossBackward>)
Difference tensor(38., grad_fn=<SumBackward0>)
Loss tensor(6.1903, grad_fn=<NllLossBackward>)
Difference tensor(38., grad_fn=<SumBackward0>)
Loss tensor(6.1903, grad_fn=<NllLossBackward>)
Difference tensor(38., grad_fn=<SumBackward0>)
Loss tensor(6.1903, grad_fn=<NllLossBackward>)
Difference tensor(38., grad_fn=<SumBackward0>)
Loss tensor(6.1903, grad_fn=<NllLossBackward>)
Difference tensor(38., grad_fn=<SumBackward0>)

As you can see, the loss never decreases below 6.1903, and the sum of absolute differences never drops below 38.
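
For scale, note that with d_model = 512 classes, even a completely uninformative uniform prediction gives a cross-entropy of ln(512) ≈ 6.24, so a loss stuck around 6.19 means the model is barely doing better than chance. A standalone sanity check (not part of my model code):

import math

import torch
import torch.nn as nn

num_classes = 512
uniform_logits = torch.zeros(1, num_classes)      # every class equally likely
target = torch.tensor([0])                        # arbitrary class index
chance_loss = nn.CrossEntropyLoss()(uniform_logits, target)
print(chance_loss.item(), math.log(num_classes))  # both roughly 6.2383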

nn.CrossEntropyLoss uses F.log_softmax and nn.NLLLoss internally.
Since you are using softmax as your last activation, note that log_softmax will be applied to these activations again in the criterion.
Could you remove the softmax and rerun the code, please?
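
To see the effect in isolation, here is a minimal standalone sketch (not your model) comparing the gradients you get from feeding raw logits versus softmax outputs into nn.CrossEntropyLoss:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
target = torch.randint(0, 512, (4,))

# Correct usage: raw logits go straight into CrossEntropyLoss,
# which applies log_softmax internally.
logits = torch.randn(4, 512, requires_grad=True)
criterion(logits, target).backward()
print("grad norm, raw logits:    ", logits.grad.norm().item())

# Applying softmax first squashes the values the criterion sees into [0, 1],
# so the loss sits near ln(512) ~= 6.24 and the gradient reaching the model is tiny.
logits2 = torch.randn(4, 512, requires_grad=True)
criterion(torch.softmax(logits2, dim=1), target).backward()
print("grad norm, double softmax:", logits2.grad.norm().item())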

I removed the softmax layer as you suggested and the situation improved, in that the loss and output now change during training. However, the output is still the same for all positions, as the following code and output log demonstrate. I have further simplified the model by removing everything except the positional encoder and the transformer, and I changed the two input tensors fed into the transformer (src and tgt_in) to zero tensors. The transformer should therefore only be able to rely on positional information to determine its output. Note that I am still computing the cross-entropy loss against a fixed random index set.

Code:

import torch
import torch.nn as nn
import random
import math

d_model = 512

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self, nhead=8, dim_feedforward=1024, num_layers=6, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.d_model = d_model
        self.pos_encoder = PositionalEncoding(self.d_model, dropout)
        self.transformer = nn.Transformer(d_model=self.d_model, nhead=nhead, dim_feedforward=dim_feedforward,
                                          num_encoder_layers=num_layers, num_decoder_layers=num_layers, dropout=dropout)

    def forward(self, src, tgt):
        # src is a tensor of shape (S, N, d_model), where N = batch size,
        # S = sequence length of input, d_model = channel size
        # tgt is a tensor of shape (T, N, d_model), where N = batch size,
        # T = sequence length of output, d_model = channel size
        tgt_len = tgt.shape[0]
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_len)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return output

src_len = 30
tgt_len = 20
batch_size = 1

src = torch.zeros((src_len, batch_size, d_model))
tgt_in = torch.zeros((tgt_len, batch_size, d_model))

# Create tgt as a tensor of one hot vectors
tgt = torch.zeros((tgt_len, batch_size, d_model))

for i in range(tgt_len):
    for j in range(batch_size):
        k = random.randrange(0, d_model)
        tgt[i, j, k] = 1.0

max_tgt_indices = torch.argmax(tgt.view(-1, d_model), dim=1)

print("max_tgt_indices", max_tgt_indices)

model = TransformerModel(dropout=0.0)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
torch.set_printoptions(profile="full")

# Only used in the difference computation
soft_max = nn.Softmax(dim=2)

# Feed the same example to the transformer 1000 times to test if it's learning anything
for _ in range(1000):
    optimizer.zero_grad()
    output = model(src, tgt_in)
    print("Output argmax", torch.argmax(output.view(-1, d_model), dim=1))
    loss = criterion(output.view(-1, d_model), max_tgt_indices)
    difference = torch.sum(torch.abs(tgt - soft_max(output)))
    print("Loss", loss)
    print("Difference", difference)
    loss.backward()
    optimizer.step()
Output:

max_tgt_indices tensor([142, 488, 254, 241, 227, 244,  33, 101, 225, 496, 132, 386,  41, 259,
        468, 426, 425, 135, 147,  58])
Output argmax tensor([368, 368, 368, 368, 280, 280, 280, 368, 368, 368, 368, 368, 368, 368,
        368, 368, 368, 368, 368, 368])
Loss tensor(6.7258, grad_fn=<NllLossBackward>)
Difference tensor(39.9208, grad_fn=<SumBackward0>)
Output argmax tensor([425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425,
        425, 425, 425, 425, 425, 425])
Loss tensor(3.4147, grad_fn=<NllLossBackward>)
Difference tensor(38.5672, grad_fn=<SumBackward0>)
Output argmax tensor([254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
        254, 254, 254, 254, 254, 254])
Loss tensor(3.4402, grad_fn=<NllLossBackward>)
Difference tensor(38.3624, grad_fn=<SumBackward0>)
Output argmax tensor([147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147,
        147, 147, 147, 147, 147, 147])
Loss tensor(3.1864, grad_fn=<NllLossBackward>)
Difference tensor(38.2882, grad_fn=<SumBackward0>)
Output argmax tensor([147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147,
        147, 147, 147, 147, 147, 147])
Loss tensor(3.1425, grad_fn=<NllLossBackward>)
Difference tensor(38.2429, grad_fn=<SumBackward0>)
Output argmax tensor([227, 227, 227, 227, 227, 227, 227, 227, 227, 227, 227, 227, 227, 227,
        227, 227, 227, 227, 227, 227])
Loss tensor(3.1301, grad_fn=<NllLossBackward>)
Difference tensor(38.2267, grad_fn=<SumBackward0>)
Output argmax tensor([227, 227, 227, 227, 227, 227, 227, 227, 227, 227, 227, 227, 227, 227,
        227, 227, 227, 227, 227, 227])
Loss tensor(3.1241, grad_fn=<NllLossBackward>)
Difference tensor(38.2256, grad_fn=<SumBackward0>)
Output argmax tensor([425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425,
        425, 425, 425, 425, 425, 425])
Loss tensor(3.1146, grad_fn=<NllLossBackward>)
Difference tensor(38.2101, grad_fn=<SumBackward0>)
Output argmax tensor([425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425, 425,
        425, 425, 425, 425, 425, 425])
Loss tensor(3.1082, grad_fn=<NllLossBackward>)
Difference tensor(38.2040, grad_fn=<SumBackward0>)
Output argmax tensor([142, 142, 142, 142, 142, 142, 142, 142, 142, 142, 142, 142, 142, 142,
        142, 142, 142, 142, 142, 142])
Loss tensor(3.1025, grad_fn=<NllLossBackward>)
Difference tensor(38.1921, grad_fn=<SumBackward0>)
Output argmax tensor([244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244, 244,
        244, 244, 244, 244, 244, 244])
Loss tensor(3.0988, grad_fn=<NllLossBackward>)
Difference tensor(38.1850, grad_fn=<SumBackward0>)
Output argmax tensor([147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147, 147,
        147, 147, 147, 147, 147, 147])
Loss tensor(3.0926, grad_fn=<NllLossBackward>)
Difference tensor(38.1792, grad_fn=<SumBackward0>)

Basically, the output jumps around between constant predictions, where the constant being output happens to be correct for exactly one position. This is what I would expect if the positional encoding were useless, since the transformer model is position agnostic without it. Also note that the output from the randomly initialized weights did produce different values at different positions (368 and 280), which seems like a good sign to me.

I double checked the positional encoding class and its output appears to be correct.
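
For example, one simple check (my own sketch, reusing the PositionalEncoding class defined above) is to confirm that different positions actually receive different offsets:

import torch

pos_encoder = PositionalEncoding(d_model=512, dropout=0.0)
x = torch.zeros(5, 1, 512)            # (S, N, d_model)
out = pos_encoder(x)
print((out[0] - out[1]).abs().sum())  # nonzero, so positions are distinguishable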

I personally have some doubts about this particular encoding scheme. For one, I don’t understand why the positional encoding isn’t just concatenated as extra channels instead of being added onto the src/tgt inputs. I am willing to set this aside for now on the assumption that the authors knew what they were doing.
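
Just to make concrete what I mean by concatenation, here is a rough sketch (purely my own illustration; pos_dim and the projection layer are arbitrary choices, not anything from the paper):

import math

import torch
import torch.nn as nn

class ConcatPositionalEncoding(nn.Module):
    def __init__(self, d_model, pos_dim=64, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, pos_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, pos_dim, 2).float() * (-math.log(10000.0) / pos_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(1))        # (max_len, 1, pos_dim)
        self.proj = nn.Linear(d_model + pos_dim, d_model)  # project back down to d_model

    def forward(self, x):
        # x: (S, N, d_model); append the position channels, then project back
        pos = self.pe[:x.size(0)].expand(-1, x.size(1), -1)
        return self.proj(torch.cat([x, pos], dim=-1))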

Thank you for taking the time to answer questions! Really appreciated 🙂

Edit: Good post explaining the position encoding here: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/

Edit 2: I ran my original code (the one that uses actual training data) and see similar behaviour.

Edit 3: Adding dropout back in seems to resolve the issue where all the outputs are the same. Convergence still seems rather slow, though I don’t know whether that is to be expected. Perhaps a larger batch size would improve things as well.

@calebh I believe there is something going on with the tgt_mask mechanism, generate_square_subsequent_mask, or something related to it. If we are getting repeated tokens, that suggests the decoder is attending to the whole sequence of embeddings at once instead of being restricted to earlier positions by the mask.
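
One quick way to rule the mask itself in or out (a standalone sketch, not the code from this thread) is to print it and check that position i can only attend to positions <= i:

import torch
import torch.nn as nn

transformer = nn.Transformer(d_model=8, nhead=2, num_encoder_layers=1, num_decoder_layers=1)
mask = transformer.generate_square_subsequent_mask(5)
print(mask)
# Expected shape (5, 5): zeros on and below the diagonal, -inf above it, e.g.
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0.,   0., -inf, -inf, -inf],
#         [0.,   0.,   0., -inf, -inf],
#         [0.,   0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.,   0.]])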

I have exactly the same issue and don’t know what to do about it. Have you been able to solve it?

I never resolved the issue, sorry. Please post back here if you figure it out.