Cannot continue training previous saved transformer model with AdamW optimizer

Hello everyone,

I am training a Transformer with AdamW for a machine translation task. Unfortunately, after around 10 epochs my GPU ran out of memory, so I'd like to save the model to disk and continue training it by rerunning my program. However, even though I load both the model parameters and the optimizer state dict, the loss goes back to its initial value every time I resume training.

Here is the code of the initializing process:

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
optimizer = torch.optim.AdamW(transformer.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)

if SAVE_STATE:
    try:
        checkpoint = torch.load(PATH, map_location=DEVICE)
        print("Checkpoint loaded successfully!")
        transformer.load_state_dict(checkpoint["model_state_dict"], strict=True)
        transformer = transformer.to(DEVICE)
        transformer.train()
        print("Loaded transformer state dict successfully!")
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print("Loaded optimizer state dict successfully!")
    except FileNotFoundError:
        print("No checkpoint found. Starting training from scratch.")
    except KeyError as e:
        print(f"Missing key in checkpoint: {e}")
    except Exception as e:
        print(f"An error occurred while loading the checkpoint: {e}")
else:
    transformer = transformer.to(DEVICE)
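For completeness, the saving side of my checkpoint code looks roughly like the sketch below. It is simplified to a toy model so it runs standalone; in my real code the saved objects are `transformer` and `optimizer`, the file path is `PATH`, and `epoch` comes from the training loop:

```python
import torch
import torch.nn as nn

# Toy stand-ins for Seq2SeqTransformer and its optimizer.
model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Bundle everything needed to resume into a single dict.
checkpoint = {
    "epoch": 10,  # last finished epoch
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")

# Resuming restores both state dicts and continues from epoch + 1.
loaded = torch.load("checkpoint.pt")
model.load_state_dict(loaded["model_state_dict"])
optimizer.load_state_dict(loaded["optimizer_state_dict"])
start_epoch = loaded["epoch"] + 1
```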

And here is the code of the model:

class Seq2SeqTransformer(nn.Module):
    def __init__(
            self,
            n_encoder_layer: int,
            n_decoder_layer: int,
            n_embed: int,
            n_head: int,
            src_vocab_size: int,
            tgt_vocab_size: int,
            dff: int = 512,
            dropout: float = 0.1
    ):
        super().__init__()
        self.transformer = nn.Transformer(
            d_model=n_embed,
            nhead=n_head,
            num_encoder_layers=n_encoder_layer,
            num_decoder_layers=n_decoder_layer,
            dim_feedforward=dff,
            dropout=dropout,
        )
        self.generator = nn.Linear(n_embed, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, n_embed)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, n_embed)
        self.positional_encoding = PositionalEncoding(n_embed, 1024, dropout)

    def forward(
            self,
            src,
            tgt,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
            memory_key_padding_mask,
    ):
        src_emb = self.positional_encoding(self.src_tok_emb(src)) # (T, B, C)
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt)) # (T, B, C)
        outs = self.transformer(
            src_emb,
            tgt_emb,
            src_mask,
            tgt_mask,
            None,
            src_padding_mask,
            tgt_padding_mask,
            memory_key_padding_mask,
        )
        return self.generator(outs) # (T, B, new_C)

    def encode(self, src, src_mask):
        return self.transformer.encoder(self.positional_encoding(self.src_tok_emb(src)), src_mask)

    def decode(self, tgt, memory, tgt_mask):
        return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)

I believe it’s just a simple model, and no background process should prevent it from loading the old state. I already checked the model state_dict after loading and it looks fine. The main problem is that the loss does not keep coming down — it restarts from its initial value, as if the model or optimizer state were never actually loaded.
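To double-check this, I ran a small sanity test along these lines (again with a toy model standing in for mine): take one optimizer step so the optimizer accumulates state, save a checkpoint, load it into fresh instances, and verify that the parameters changed and the optimizer state survived the round trip:

```python
import torch
import torch.nn as nn

# Toy model standing in for Seq2SeqTransformer.
model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# One training step so the optimizer accumulates state (exp_avg etc.).
loss = model(torch.randn(2, 4)).sum()
loss.backward()
optimizer.step()

torch.save({"model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()}, "ckpt.pt")

# Fresh instances, as after restarting the program.
model2 = nn.Linear(4, 4)
optimizer2 = torch.optim.AdamW(model2.parameters(), lr=1e-3)
before = model2.weight.detach().clone()

ckpt = torch.load("ckpt.pt")
model2.load_state_dict(ckpt["model_state_dict"])
optimizer2.load_state_dict(ckpt["optimizer_state_dict"])

# Both should be True if the checkpoint round trip works.
changed = not torch.equal(before, model2.weight.detach())
has_state = len(optimizer2.state_dict()["state"]) > 0
print(changed, has_state)
```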

Here is a sample log of the training process (the tqdm progress bars look messy because I’m still testing them, so please don’t mind that):

First training:

Start training...
  0%|                                       | 2/2079 [00:00<05:25,  6.38it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 136.17it/s]
Epoch: 1, Train loss: 6.408, Val loss: 6.446
  0%|                                       | 2/2079 [00:00<00:23, 89.34it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 142.82it/s]
Epoch: 2, Train loss: 6.367, Val loss: 6.443
  0%|                                      | 2/2079 [00:00<00:20, 100.72it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 138.72it/s]
Epoch: 3, Train loss: 6.331, Val loss: 6.429
  0%|                                       | 2/2079 [00:00<00:23, 89.82it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 139.64it/s]
Epoch: 4, Train loss: 6.275, Val loss: 6.418
  0%|                                       | 2/2079 [00:00<00:23, 90.18it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 137.44it/s]
Epoch: 5, Train loss: 6.231, Val loss: 6.407
  0%|                                       | 2/2079 [00:00<00:22, 91.01it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 136.57it/s]
Epoch: 6, Train loss: 6.177, Val loss: 6.384
  0%|                                       | 2/2079 [00:00<00:22, 92.38it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 141.61it/s]
Epoch: 7, Train loss: 6.156, Val loss: 6.354
  0%|                                       | 2/2079 [00:00<00:22, 93.42it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 139.16it/s]
Epoch: 8, Train loss: 6.105, Val loss: 6.329
  0%|                                       | 2/2079 [00:00<00:22, 93.68it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 141.56it/s]
Epoch: 9, Train loss: 6.085, Val loss: 6.311
  0%|                                       | 2/2079 [00:00<00:22, 90.79it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 134.43it/s]
Epoch: 10, Train loss: 6.048, Val loss: 6.303
  0%|                                       | 2/2079 [00:00<00:22, 90.45it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 142.22it/s]
Epoch: 11, Train loss: 6.021, Val loss: 6.284
  0%|                                       | 2/2079 [00:00<00:22, 91.14it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 140.51it/s]
Epoch: 12, Train loss: 5.995, Val loss: 6.260
  0%|                                       | 2/2079 [00:00<00:22, 94.00it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 140.91it/s]
Epoch: 13, Train loss: 5.966, Val loss: 6.247
  0%|                                       | 2/2079 [00:00<00:22, 90.39it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 140.79it/s]
Epoch: 14, Train loss: 5.943, Val loss: 6.235
  0%|                                       | 2/2079 [00:00<00:22, 92.67it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 142.17it/s]
Epoch: 15, Train loss: 5.920, Val loss: 6.215

Second training:

Start training...
  0%|                                       | 2/2079 [00:00<06:03,  5.72it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 136.98it/s]
Epoch: 1, Train loss: 6.421, Val loss: 6.451
  0%|                                       | 2/2079 [00:00<00:23, 90.06it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 143.49it/s]
Epoch: 2, Train loss: 6.379, Val loss: 6.448
  0%|                                      | 2/2079 [00:00<00:20, 102.31it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 137.98it/s]
Epoch: 3, Train loss: 6.341, Val loss: 6.433
  0%|                                       | 2/2079 [00:00<00:22, 92.47it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 142.71it/s]
Epoch: 4, Train loss: 6.283, Val loss: 6.423
  0%|                                       | 2/2079 [00:00<00:22, 91.21it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 139.05it/s]
Epoch: 5, Train loss: 6.238, Val loss: 6.411
  0%|                                       | 2/2079 [00:00<00:22, 93.44it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 143.37it/s]
Epoch: 6, Train loss: 6.183, Val loss: 6.388
  0%|                                       | 2/2079 [00:00<00:22, 93.45it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 140.80it/s]
Epoch: 7, Train loss: 6.161, Val loss: 6.358
  0%|                                       | 2/2079 [00:00<00:22, 94.32it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 144.01it/s]
Epoch: 8, Train loss: 6.110, Val loss: 6.333
  0%|                                       | 2/2079 [00:00<00:22, 94.20it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 143.04it/s]
Epoch: 9, Train loss: 6.089, Val loss: 6.314
  0%|                                       | 2/2079 [00:00<00:22, 92.66it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 140.91it/s]
Epoch: 10, Train loss: 6.051, Val loss: 6.307
  0%|                                       | 2/2079 [00:00<00:22, 91.59it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 142.22it/s]
Epoch: 11, Train loss: 6.025, Val loss: 6.287
  0%|                                       | 2/2079 [00:00<00:22, 91.74it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 140.68it/s]
Epoch: 12, Train loss: 5.998, Val loss: 6.262
  0%|                                       | 2/2079 [00:00<00:22, 93.57it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 139.26it/s]
Epoch: 13, Train loss: 5.970, Val loss: 6.249
  0%|                                       | 2/2079 [00:00<00:22, 92.84it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 142.15it/s]
Epoch: 14, Train loss: 5.946, Val loss: 6.237
  0%|                                       | 2/2079 [00:00<00:22, 91.88it/s]
  8%|███▏                                    | 2/25 [00:00<00:00, 141.52it/s]
Epoch: 15, Train loss: 5.924, Val loss: 6.216

Interestingly, the second run’s losses track the first run almost epoch for epoch, as if training restarted from a freshly initialized model. Can anyone tell me which part I am doing wrong? I’d be grateful for any advice on this code.
Thanks a lot!