Training loss keeps increasing after reloading a saved model

I'm training a model, and on the first training run everything looks great no matter how many epochs I train (10, 20, 30, ...): the accuracy keeps increasing and the loss keeps decreasing. But when I load the saved model and continue training, the loss increases every epoch. Where could the problem be? I tried saving the optimizer's state_dict and loading it for the next run, but that didn't work out; the loss still increases. I use the Adam optimizer and a cross-entropy loss function. This has been bothering me for several days. :frowning:
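For reference, the save/load pattern I tried for the optimizer was roughly the following (a minimal sketch, not my exact code; the "model" / "optimizer" keys are just illustrative):

# save: bundle model and optimizer state dicts into one checkpoint file
checkpoint = {
    "model": net.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": epoch,
}
torch.save(checkpoint, os.path.join(save_path, "checkpoint.pt"))

# resume: restore both state dicts before continuing training
checkpoint = torch.load(model_path, map_location=conf.device)
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])

Here is the training code: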

import os

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def train(
    epochs: int,
    batch_size: int,
    net: torch.nn.Module,
    trainDataLoader: DataLoader,
    testDataLoader: DataLoader,
    device: str,
    lossF: torch.nn.modules.loss._WeightedLoss,
    optimizer: torch.optim.Optimizer,
    save_path: str,
):
    for epoch in range(1, epochs + 1):
        processBar = tqdm(trainDataLoader, unit="step")
        net.train(True)
        for step, (train_seq, train_labels) in enumerate(processBar):
            train_seq = train_seq.to(device)
            train_labels = train_labels.to(device)
            optimizer.zero_grad()
            outputs = net(train_seq)
            loss = lossF(outputs, train_labels)
            predictions = torch.argmax(outputs, dim=1)
            accuracy = torch.sum(predictions == train_labels) / train_labels.shape[0]
            loss.backward()
            optimizer.step()
            processBar.set_description(
                "[%d/%d] Loss: %.4f, Acc: %.4f"
                % (epoch, epochs, loss.item(), accuracy.item())
            )
            if step == len(processBar) - 1:
                correct, total_loss = 0, 0
                net.train(False)
                with torch.no_grad():
                    for test_seq, test_labels in testDataLoader:
                        test_seq = test_seq.to(device)
                        test_labels = test_labels.to(device)
                        test_out = net(test_seq)
                        tloss = lossF(test_out, test_labels)
                        predictions = torch.argmax(test_out, dim=1)
                        total_loss += tloss
                        correct += torch.sum(predictions == test_labels)
                test_acc = correct / len(testDataLoader.dataset)  # use the actual sample count (the last batch may be smaller)
                test_loss = total_loss / len(testDataLoader)
                processBar.set_description(
                    "[%d/%d] Loss: %.4f, Acc: %.4f, Test Loss: %.4f, Test Acc: %.4f"
                    % (
                        epoch,
                        epochs,
                        loss.item(),
                        accuracy.item(),
                        test_loss.item(),
                        test_acc.item(),
                    )
                )
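        # note: only the model weights are saved here; the optimizer state is not part of this checkpoint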
        model_save_path = os.path.join(save_path, "checkpoint.pt")
        with open(model_save_path, "wb") as f:
            torch.save(net.state_dict(), f)
        processBar.close()


def main():
    conf = config.AllConfig
    model_path = os.path.join(conf.save_path, "checkpoint.pt")
    model = TaxonClassifier.TaxonModel(
        vocab_size=conf.vocab_size,
        embedding_size=conf.embedding_size,
        hidden_size=conf.hidden_size,
        device=conf.device,
        max_len=conf.max_len,
        num_layers=conf.num_layers,
        num_class=conf.num_class,
        drop_out=conf.drop_prob,
    )
    model = model.to(device=conf.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=conf.lr)
    if os.path.exists(model_path):
        print("Loading existing model state_dict......")
        checkpoint = torch.load(model_path, map_location=conf.device, weights_only=True)
        model.load_state_dict(checkpoint)
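        # only the model weights are restored here; the Adam optimizer created above starts from a fresh state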
    else:
        print("No existing model state......")
    print("Loading Dict Files......")
    all_dict = Dataset.Dictionary(conf.KmerFilePath, conf.TaxonFilePath)
    print("Loading dataset......")
    all_dataset = Dataset.AllDataset(conf.DataPath, conf.max_len, all_dict, conf.kmer)
    train_dataloader = DataLoader(
        dataset=all_dataset.train_dataset,
        batch_size=conf.batch_size,
        shuffle=True,
        num_workers=4,
    )
    test_dataloader = DataLoader(
        dataset=all_dataset.test_dataset,
        batch_size=conf.batch_size,
        shuffle=False,
        num_workers=4,
    )
    lossF = torch.nn.CrossEntropyLoss()
    print("Start Training")
    train(
        epochs=conf.epoch,
        batch_size=conf.batch_size,
        net=model,
        trainDataLoader=train_dataloader,
        testDataLoader=test_dataloader,
        device=conf.device,
        lossF=lossF,
        optimizer=optimizer,
        save_path=conf.save_path,
    )

And here is the model:

import torch
from torch import nn
from torch.nn import functional as F
from . import LSTMLayer, EmbeddingLayer


class TaxonModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_size: int,
        hidden_size: int,
        device: str,
        max_len: int,
        num_layers: int,
        num_class: int,
        drop_out: float = 0.5,
    ):
        super(TaxonModel, self).__init__()
        self.num_layers = num_layers
        self.num_class = num_class
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.device = device
        self.max_len = max_len
        self.drop_out = drop_out
        self.seq_encoder = LSTMLayer.SeqEncoder(
            embedding_size, hidden_size, num_layers, drop_out
        )
        self.embedding = EmbeddingLayer.FullEmbedding(
            vocab_size, embedding_size, max_len, device, drop_out
        )
        # attention-related parameters
        self.key_matrix = nn.Parameter(
            torch.Tensor(hidden_size * 2, hidden_size * 2), requires_grad=True
        )
        self.query = nn.Parameter(torch.Tensor(hidden_size * 2), requires_grad=True)
        # initialize the attention parameters
        nn.init.uniform_(self.key_matrix, -0.1, 0.1)
        nn.init.uniform_(self.query, -0.1, 0.1)
        # decoder: maps the attended representation to class logits
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.GELU(),
            nn.Dropout(drop_out),
            nn.Linear(hidden_size, num_class),
        )

    def forward(self, x):
        x = self.embedding(x)  # [batch_size,seq_len,emb_size]
        x = x.permute(1, 0, 2)  # [seq_len,batch_size,emb_size]
        x = self.seq_encoder(x)  # x: [seq_len,batch_size,hidden_size*2]
        x = x.permute(1, 0, 2)  # [batch_size,seq_len,hidden_size*2]
        key = torch.tanh(
            torch.matmul(x, self.key_matrix)
        )  # [batch_size,seq_len,hidden_size*2]

        # torch.matmul(key, self.query) yields [batch_size,seq_len] (dot product along the last dim),
        # then softmax over dim 1
        score = F.softmax(torch.matmul(key, self.query), dim=1).unsqueeze(
            -1
        )  # [batch_size,seq_len,1]

        x = x * score  # [batch_size,seq_len,hidden_size*2]
        x = torch.sum(x, dim=1)  # [batch_size,hidden_size*2]
        final_outputs = self.decoder(x)
        return final_outputs