Transformer won't learn test sample

Hi! I want to experiment with a simple transformer model, and as a test I am trying to make it overfit a single sample. Although I followed the tutorial and most of my code is copy-pasted, the model can’t learn even that one sample. I looked at similar questions on the forum, but none of them solved the issue. Here is my code:

def _create_padding_mask(inp: torch.Tensor, padding_idx: int):
    """Mark the padding positions of a token-index tensor.

    Args:
        inp: tensor of token indices.
        padding_idx: the index value used for padding.

    Returns:
        A boolean tensor with the same shape as ``inp`` that is True exactly
        where ``inp`` equals ``padding_idx``.
    """
    is_padding = inp == padding_idx
    return is_padding.to(torch.bool)

class simpleTransformer(nn.Module):
    """Encoder-decoder ``nn.Transformer`` with token embeddings, positional
    encoding and a linear projection to target-vocabulary logits.

    ``src`` and ``tgt`` are batch-first token-index tensors ``(batch, seq_len)``.
    They are transposed to sequence-first because ``nn.Transformer`` is built
    with its default ``batch_first=False``. The returned logits are therefore
    sequence-first: ``(tgt_len, batch, vocab_size_tgt)``. A caller computing a
    loss against batch-first targets must transpose one side before flattening,
    or tokens of different sequences get paired up whenever batch > 1.
    """

    def __init__(self, emb_dim, vocab_size_src, vocab_size_tgt, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, pos_dropout, trans_dropout, tgt_padding_idx, streamFormat, max_seq_length = 1000, src_padding_idx=None):
        """
        Args:
            emb_dim: model width (embedding size / ``d_model``).
            vocab_size_src: number of source tokens.
            vocab_size_tgt: number of target tokens (size of the output logits).
            nhead, num_encoder_layers, num_decoder_layers, dim_feedforward:
                forwarded positionally to ``nn.Transformer``.
            pos_dropout: dropout applied after adding positional encodings.
            trans_dropout: dropout inside the transformer layers.
            tgt_padding_idx: target token index treated as padding (masked in
                the decoder's self-attention).
            streamFormat: accepted for interface compatibility; not used here.
            max_seq_length: longest sequence the positional encoding supports.
            src_padding_idx: optional source token index treated as padding. When
                given, it masks encoder self-attention and decoder cross-attention.
                When None (the default), the source is not masked.
        """
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise PyTorch raises "cannot assign module before Module.__init__() call".
        super().__init__()
        self.emb_dim = emb_dim
        self.tgt_padding_idx = tgt_padding_idx
        self.src_padding_idx = src_padding_idx
        # Cached causal mask; _get_tgt_mask rebuilds it when the length or device changes.
        self.tgt_mask = None

        self.embed_src = nn.Embedding(vocab_size_src, emb_dim)
        self.embed_tgt = nn.Embedding(vocab_size_tgt, emb_dim)
        self.pos_enc = PositionalEncoding(emb_dim, pos_dropout, max_seq_length)

        self.transformer = nn.Transformer(emb_dim, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, trans_dropout)
        self.fc = nn.Linear(emb_dim, vocab_size_tgt)

    def forward(self, src, tgt):
        """Run the encoder-decoder and project to target-vocabulary logits.

        Args:
            src: source token indices, ``(batch, src_len)``.
            tgt: decoder-input token indices, ``(batch, tgt_len)``.

        Returns:
            Logits of shape ``(tgt_len, batch, vocab_size_tgt)`` (sequence-first).
        """
        tgt_mask = self._get_tgt_mask(tgt.size(1), tgt.device)
        tgt_padding_mask = _create_padding_mask(tgt, self.tgt_padding_idx)
        src_padding_mask = None
        if self.src_padding_idx is not None:
            src_padding_mask = _create_padding_mask(src, self.src_padding_idx)

        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        scale = math.sqrt(self.emb_dim)
        src_emb = self.pos_enc(self.embed_src(src).transpose(0, 1) * scale)
        tgt_emb = self.pos_enc(self.embed_tgt(tgt).transpose(0, 1) * scale)

        output = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            # The decoder attends over the encoder output, which has the source layout.
            memory_key_padding_mask=src_padding_mask,
        )
        return self.fc(output)

    def _get_tgt_mask(self, size, device):
        """Return a ``(size, size)`` causal mask on ``device``.

        Reuses ``self.tgt_mask`` when it already matches. Otherwise it builds a
        new mask and caches that. A mask cached for the first batch's length
        would be wrong for any later target of a different length.
        """
        mask = self.tgt_mask
        if mask is None or mask.size(0) != size or mask.device != device:
            mask = self._create_square_subsequent_mask(size).to(device)
            self.tgt_mask = mask
        return mask

    def _create_square_subsequent_mask(self, dim):
        """Target mask. Prevents decoder from attending to future positions.

        Returns a float ``(dim, dim)`` tensor. It is 0.0 on and below the
        diagonal, where attention is allowed, and -inf above it, where attention
        is blocked.
        """
        mask = (torch.triu(torch.ones(dim, dim)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    # TODO: beam search decoding.

class PositionalEncoding(nn.Module):
    """Absolute sinusoidal positional encoding (adapted from a PyTorch example).

    Adds a fixed sin/cos table to a sequence-first input
    ``(seq_len, batch, emb_dim)``, then applies dropout.
    """

    def __init__(self, emb_dim, dropout=0.1, max_len=1000):
        """
        Args:
            emb_dim: embedding size; may be odd.
            dropout: dropout probability applied after adding the encoding.
            max_len: number of positions precomputed; longer inputs are rejected.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_dim, 2).float() * (-math.log(10000.0) / emb_dim))
        pe = torch.zeros(max_len, emb_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd emb_dim there is one fewer odd column than there are frequencies.
        pe[:, 1::2] = torch.cos(position * div_term)[:, :pe[:, 1::2].size(1)]
        # Shape (max_len, 1, emb_dim) so the table broadcasts over the batch axis.
        # A buffer moves with model.to(device) and is saved in state_dict,
        # but it is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the encodings for the first ``x.size(0)`` positions and apply dropout.

        Raises:
            ValueError: if ``x`` has more positions than ``max_len``.
        """
        if x.size(0) > self.pe.size(0):
            raise ValueError(
                f"sequence length {x.size(0)} exceeds max_len={self.pe.size(0)}"
            )
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

The model always outputs the same label. I have already checked the attention mask, but maybe I wired it incorrectly. Do you see anything wrong with the model itself?

It looks right, but you didn’t include your training code.
Do you initialize the optimizer with the model’s parameters? Do you call `.backward()` on your loss? Do you call `optimizer.step()`? And so on…

Here is essentially the training code, I removed a few non-essential things for clarity.

class Trainer:
    """Training loop that repeatedly fits one fixed batch (an overfitting sanity check).

    Batches are dicts with ``'midi'`` (source) and ``'chords'`` (target token)
    entries, and the model is called as ``model(src, tgt_inp)``. Parts of the
    class were trimmed from the posted code. In particular, ``init_scheduler``
    is called below but is not defined here; it must be provided elsewhere and
    is expected to set ``self.scheduler`` (NOTE(review): confirm).
    """

    def __init__(self, model, train_ds, val_ds,
                 train_bs, val_bs, ignore_idx, dest_path):
        """
        Args:
            model: module called as ``model(src, tgt_inp)``.
            train_ds, val_ds: datasets wrapped in DataLoaders of batch size
                ``train_bs`` / ``val_bs``.
            ignore_idx: target index excluded from the loss (e.g. padding).
            dest_path: accepted for interface compatibility; not used in this code.
        """
        self.model = model
        # optimizer and scheduler are initialized in the train() method
        self.opti = None
        self.scheduler = None

        self.ignore_idx = ignore_idx
        # Target positions equal to ignore_idx contribute nothing to the loss.
        self.loss = nn.CrossEntropyLoss(ignore_index=self.ignore_idx)
        self.train_dl = DataLoader(train_ds, train_bs)
        self.val_dl = DataLoader(val_ds, val_bs)

    def _make_optimizer(self, name, lr):
        """Build only the requested optimizer over the model's parameters.

        Raises:
            ValueError: if ``name`` is not one of the supported optimizers.
        """
        # Construct lazily: the original built every optimizer up front and
        # silently returned None for an unknown name.
        choices = {
            "Adam": (optim.Adam, {}),
            "SGD": (optim.SGD, {"momentum": 0.9}),
            "Adadelta": (optim.Adadelta, {}),
            "Adagrad": (optim.Adagrad, {}),
            "Adamax": (optim.Adamax, {}),
            "ASGD": (optim.ASGD, {}),
            "LBFGS": (optim.LBFGS, {}),
            "RMSprop": (optim.RMSprop, {}),
        }
        try:
            opt_cls, extra = choices[name]
        except KeyError:
            raise ValueError(
                f"unknown optimizer {name!r}; expected one of {sorted(choices)}"
            ) from None
        return opt_cls(self.model.parameters(), lr=lr, **extra)

    def _compute_loss(self, preds, targets):
        """Cross-entropy between sequence-first logits and batch-first targets.

        ``preds`` is ``(tgt_len, batch, vocab)``, the layout ``simpleTransformer``
        returns because it transposes to sequence-first for ``nn.Transformer``.
        ``targets`` is ``(batch, tgt_len)``. Flattening both as-is would pair
        each logit row with a token from another sequence whenever batch > 1,
        so the targets are transposed to sequence-first before flattening.
        """
        logits = preds.reshape(-1, preds.size(-1))
        labels = targets.transpose(0, 1).reshape(-1)
        return self.loss(logits, labels)

    def train(self, epochs: int, start_lr: float, lr_scheduler='no', **kwargs):
        """Overfit a single batch drawn from the training loader.

        Args:
            epochs: number of passes. Each pass takes ``len(self.train_dl)``
                optimizer steps, so per-step schedulers see the step count they
                were configured with (``steps_per_epoch``).
            start_lr: initial learning rate.
            lr_scheduler: scheduler name forwarded to ``init_scheduler``.
                ``'one_cycle'`` and ``'cyclic'`` are stepped after every optimizer step.
            **kwargs: ``optimizer`` (default ``"Adam"``) plus the scheduler settings
                ``lr_factor``, ``patience``, ``reduce_every`` and ``step_size_up``.

        Returns:
            The mean training loss of each epoch, in order.

        Raises:
            ValueError: if ``kwargs['optimizer']`` names an unsupported optimizer.
        """
        self.opti = self._make_optimizer(kwargs.get("optimizer", "Adam"), start_lr)

        # initialize scheduler
        lr_factor = kwargs.get('lr_factor', 0.5)
        patience = kwargs.get('patience', 5)
        reduce_every = kwargs.get('reduce_every', 30)
        step_size_up = kwargs.get('step_size_up', 400)
        self.init_scheduler(lr_scheduler, lr_factor=lr_factor, patience=patience, n_epochs=epochs,
                            steps_per_epoch=len(self.train_dl), base_lr=start_lr * 0.5,
                            max_lr=start_lr * 15, step_size_up=step_size_up,
                            reduce_at_ep=reduce_every)

        # Draw one batch and keep it for the rest of training.
        # next(iter(...)) works on every PyTorch version; the iterator's
        # .next() alias has been removed from recent releases.
        test_sample = next(iter(self.train_dl))
        midis, chordseqs = test_sample['midi'], test_sample['chords']
        # Teacher forcing: the decoder reads tokens [0, n-1) and predicts tokens
        # [1, n), so the decoder input starts with the first token (BOS).
        chordseqs_inp, chordseqs_tgt = chordseqs[:, :-1], chordseqs[:, 1:]
        # NOTE(review): tensors are not moved to the model's device here; add
        # .to(device) if the model lives on a GPU.

        def closure():
            # One full forward/backward pass. Passing a closure to step() is
            # accepted by every torch optimizer and required by LBFGS.
            self.opti.zero_grad()
            preds = self.model(midis, chordseqs_inp)
            loss = self._compute_loss(preds, chordseqs_tgt)
            loss.backward()
            # Gradient clipping (torch.nn.utils.clip_grad_norm_) would go here, after backward().
            return loss

        self.model.train()
        steps_per_epoch = len(self.train_dl)
        epoch_losses = []
        for _epoch in range(epochs):
            per_epoch_train_loss = 0.0
            # Same number of steps as a full pass over train_dl, without loading
            # batches that would be discarded anyway.
            for _step in range(steps_per_epoch):
                train_loss = self.opti.step(closure)
                if lr_scheduler == 'one_cycle' or lr_scheduler == 'cyclic':
                    self.scheduler.step()
                per_epoch_train_loss += train_loss.item()
            # NOTE(review): epoch-level scheduler stepping and validation were
            # trimmed from the posted code; re-add them here if needed.
            epoch_losses.append(per_epoch_train_loss / steps_per_epoch)
        return epoch_losses

Everything you mention is essentially there. I also tried something as simple as predicting b from a with:

# Toy next-token task: the target `b` is `a` with its first element dropped.
# NOTE(review): `a` is a float tensor of shape (1, 6, 1). nn.Embedding needs
# integer indices of shape (batch, seq), so this input only fits a model whose
# source embedding was replaced ("after adapting the dimensions") — confirm.
a = torch.tensor([[[1.0], [2.0], [3.0], [1.0], [2.0], [3.0]]])
b = torch.tensor([[2, 3, 1, 2, 3]], dtype=torch.long)

I adapted the model’s dimensions accordingly. Even then, the loss doesn’t decrease to 0, so something is definitely very wrong, but for the life of me I can’t figure out what.

Did you ever figure out what it was?