nn.TransformerEncoder-based language model always generates the most common token

I’m trying to do language modeling on a custom dataset using nn.TransformerEncoder. I’m using https://github.com/pytorch/examples/tree/master/word_language_model as a reference.

Previously, I used Google’s Trax and its TransformerLM model to train a transformer with this dataset, based on this example: https://github.com/jalammar/jalammar.github.io/blob/master/notebooks/Trax_TransformerLM_Intro.ipynb. There, I reduced the Adam learning rate to 1e-04, replaced the data with my actual dataset, adjusted the hyperparameters for my use-case, and managed to get very good results.
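
Here's roughly what my Trax setup looked like, reconstructed from memory, so treat the exact arguments as approximate; train_stream, eval_stream, and n_steps are placeholders for my own data pipeline:

import trax
from trax import layers as tl
from trax.supervised import training

# TransformerLM over my vocabulary; the sizes here are illustrative
model = trax.models.TransformerLM(vocab_size=n_tokens, d_model=512, d_ff=2048,
        n_layers=6, n_heads=8, mode='train')

train_task = training.TrainTask(
    labeled_data=train_stream,             # my own generator of (input, target) batches
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(1e-4),  # the reduced learning rate
)

eval_task = training.EvalTask(
    labeled_data=eval_stream,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
)

loop = training.Loop(model, train_task, eval_tasks=[eval_task], output_dir='train_dir')
loop.run(n_steps)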

I’m now trying to replicate the same result in PyTorch, but without much luck.

Here are the things I modified from the word_language_model example above:

  • Changed the batching and dataset loading logic. Since my dataset consists of separate sentences, my input data has shape (num_examples, seq_len), where each element is a token index and 0 is reserved for padding. From this I generate model inputs of shape (max_seq_len, batch_size).
  • Calculated a key padding mask that is True wherever the original data has 0 (the padding token) and passed it as src_key_padding_mask to nn.TransformerEncoder (see the sketch after this list).
  • Added an Adam optimizer instead of the example's inline LR-annealing-on-plateau scheme.
  • Removed the F.log_softmax call, so the model now returns raw logits from the last nn.Linear layer, and changed the loss to nn.CrossEntropyLoss(ignore_index=0).
  • Added accuracy measurement by masking the final logits with the key padding mask, taking an argmax, and comparing to the target.
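
To make that concrete, the padding-mask helper referenced in the training loop below is essentially just this (a simplified sketch; the real code has a bit more bookkeeping):

def calc_key_padding_mask(transposed_data, target):
    # transposed_data: (bs, seq_len) of token indices, where 0 means padding.
    # Returns a (bs, seq_len) bool tensor that is True for real tokens; the
    # training loop negates it before handing it to the model, so the
    # src_key_padding_mask the encoder sees is True exactly at padded positions.
    return transposed_data != 0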

My training plateaus within a few dozen batches at an accuracy of ~10% and a constant loss: the model simply predicts the most common token overall, no matter the input (that token makes up around 10% of the training data, hence the ~10% accuracy). I've tried using a learning rate scheduler, adding warmup steps, playing around with the hyperparameters and the initial learning rate, and changing the initialization of the embedding and final nn.Linear weights to nn.init.xavier_uniform_, but nothing helps.
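
For example, one of the warmup variants I tried looked roughly like this (the exact warmup_steps value varied between runs):

# Linear warmup of the LR over the first warmup_steps updates, then constant;
# scheduler.step() is called after every optimizer.step().
warmup_steps = 4000  # illustrative value

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))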

For comparison, I tried randomly shuffling my data around to make it nonsense, and the model arrives at similar numbers. So I’m pretty sure it’s learning nothing.

Here’s an example of how my training data looks (right before it goes into the model):

data =
       [[  1,   1,   1,   1],
        [  2,   2,   2,   2],
        [  3, 107,  81,  81],
        [115,   4, 111,  46],
        [  5,   5,   5,   5],
        [ 80,  41,  63,  61],
        # some zeroes here, starting at different rows,
        # depending on the size of the example
        ...]]

target =
       [  2,   3, 115,   5,  80,  42,  ...,
          0,   0,   0,   0,   0,   0,  ...,
          2, 107,   4,   5,  41,  96,  ...,
          0,   0,   0,   0,   0,   0,  ...]

(Almost every example starts with 1 2, so I imagined it’d be easy to fit on at least this, but not even that happens.)

Here’s the model:

import math
import torch
import torch.nn as nn

# PositionalEncoding definition from word_language_model omitted
# for brevity

class Transformer(nn.Module):
    def __init__(self, n_tokens, d_model, n_heads, d_ff, n_layers, dropout=0.1,
            max_len=4096, activation='relu'):
        super(Transformer, self).__init__()

        self.mask = None
        self.d_model = d_model

        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
        self.embedding = nn.Embedding(n_tokens, d_model, padding_idx=0)
        enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, d_ff, dropout)
        self.tf_encoder = nn.TransformerEncoder(enc_layer, n_layers)
        self.decoder = nn.Linear(d_model, n_tokens)

        self.init_weights()

    def _from_binary_mask(self, mask):
        return mask.float() \
            .masked_fill(mask == False, float('-inf')) \
            .masked_fill(mask == True, float(0.0))

    def _generate_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return self._from_binary_mask(mask)

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.embedding.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)
        # Alternative I tried:
        # nn.init.xavier_uniform_(self.embedding.weight)
        # nn.init.xavier_uniform_(self.decoder.weight)
        # nn.init.normal_(self.decoder.bias, 1e-6)

    def forward(self, src, use_mask=True, src_key_padding_mask=None):
        if use_mask:
            device = src.device
            if self.mask is None or self.mask.size(0) != len(src):
                mask = self._generate_mask(len(src)).to(device)
                self.mask = mask
        else:
            self.mask = None

        # embed and add positional information
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)

        output = self.tf_encoder(src, self.mask,
                src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(output)
        return output
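
As a quick shape sanity check (assuming the PositionalEncoding from word_language_model is defined), this is how the model gets called; the sizes here are made up just for illustration:

m = Transformer(n_tokens=200, d_model=64, n_heads=4, d_ff=128, n_layers=2)
src = torch.randint(1, 200, (50, 8))        # (seq_len, bs) of token indices
pad = torch.zeros(8, 50, dtype=torch.bool)  # (bs, seq_len), True = padding
out = m(src, src_key_padding_mask=pad)
print(out.shape)                            # torch.Size([50, 8, 200]), i.e. (seq_len, bs, n_tokens)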

And here’s my training loop:

model = model.Transformer(n_tokens, args.d_model, args.n_heads, args.d_ff,
        args.n_layers, args.dropout).to(device)

optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-3)
loss = nn.CrossEntropyLoss(ignore_index=0)

# init data, etc.

def train(train_data):
    model.train()
    bs = args.batch_size

    total_loss = 0.0
    total_correct = 0
    total_total = 0

    start_time = time.time()
    for batch_i in range(0, train_data.size(0) // bs):
        transposed_data, target = get_batch(train_data, batch_i) # (bs, seq_len), (bs*seq_len)
        data = transposed_data.transpose(0, 1).type(torch.LongTensor).to(device) # (seq_len, bs)
        target = target.type(torch.LongTensor).to(device) # (bs*seq_len)

        optimizer.zero_grad()
        key_padding_mask = calc_key_padding_mask(transposed_data, target) # (bs, seq_len), 1 if not padding
        output = model(data, src_key_padding_mask=torch.logical_not(key_padding_mask))
        output = output.view(-1, n_tokens) # (bs*seq_len, n_tokens)

        logits = F.log_softmax(output, dim=-1) # (bs*seq_len, n_tokens)
        masked_logits = apply_mask_to_logits(logits, key_padding_mask) # (bs*seq_len, n_tokens), sets element to -inf if padding
        confidences, predictions = torch.max(masked_logits.exp(), 1) # (bs*seq_len), confidence will be 0 if padding
        total_correct += (torch.logical_and(predictions == target, confidences > 0.0)).float().sum()
        total_total += key_padding_mask.float().sum()

        e = loss(output, target)
        e.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        optimizer.step()

        total_loss += e.item()

        if batch_i % args.log_interval == 0 and batch_i > 0:
            curr_loss = total_loss / args.log_interval / bs
            curr_accuracy = total_correct / total_total
            total_loss = 0.
            total_correct = 0
            total_total = 0

            # print the stats...
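
The outer loop just calls this once per epoch, the same way the word_language_model example does (simplified):

for epoch in range(1, args.epochs + 1):
    train(train_data)
    # evaluate on validation data, log, save checkpoints, etc.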

Please help - I’m at a loss (no pun intended). I feel like I’ve tried everything.