Using a single model for both training and inference: what is the best way?

Dear readers,

For various reasons, I would like to train a model (on the GPU) and keep track of “the best” model weights (those with the lowest loss so far). Sometimes I would like to switch the same model to inference and load “the best” weights. When I then resume training, I would like to continue where I left off, not from “the best” weights.

So my approach is the following:

  • Save a copy of the state dict (deepcopy) every time a new lowest loss is reached
  • Save the (training) state dict every time I switch to inference
  • Load the lowest-loss state dict when doing inference
  • Load the training state dict again when inference is finished

It does not seem to work: the best-weights state dict keeps getting updated. Is deepcopy not the right approach? For various reasons, I would like to keep everything in memory rather than saving it to a file.
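One quick way to check whether the copy itself is the problem: `state_dict()` returns tensors that share storage with the model's parameters, so keeping it without a copy gives a live view that changes with every optimizer step, while `copy.deepcopy` gives an independent snapshot. A minimal sketch, using a toy `nn.Linear` as a stand-in for the real model and an in-place `add_` as a stand-in for an optimizer step:

```python
import copy

import torch
import torch.nn as nn

# Toy model standing in for the real network (assumption for illustration)
model = nn.Linear(1, 1)

# state_dict() tensors share storage with the parameters: a live view, not a snapshot
live_view = model.state_dict()

# deepcopy allocates new tensors: an independent snapshot
snapshot = copy.deepcopy(model.state_dict())

# Simulate an optimizer step by updating the weights in place
with torch.no_grad():
    model.weight.add_(1.0)

# The live view follows the update; the snapshot keeps the old values
live_follows = torch.equal(live_view["weight"], model.weight.detach())
snapshot_follows = torch.equal(snapshot["weight"], model.weight.detach())
print(live_follows, snapshot_follows)
```

If the stored “best” dict behaves like `live_view` here, a copy is missing somewhere; if it behaves like `snapshot`, the weights are being stored correctly.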

Here it is in code:

Trainer:

        loss, model_metrics = self.model(self.batch)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.total_number_of_epochs += 1

        stopping_criterium = model_metrics["loss"]

        if stopping_criterium < self.best_stopping_criterium:
            self.best_stopping_criterium = stopping_criterium
            self.best_loss = model_metrics['loss']

            # Save state dict to disk and to memory
            self.model.save(self.model_weight_path)
            self.best_weights = copy.deepcopy(self.model.model.state_dict())
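If GPU memory becomes a concern, the in-memory best weights do not have to live on the GPU. A minimal sketch of an alternative to `copy.deepcopy` that keeps the snapshot in CPU RAM; `snapshot_state_dict` is a hypothetical helper, not part of PyTorch, and the toy `nn.Linear` stands in for `self.model.model`:

```python
import torch
import torch.nn as nn


def snapshot_state_dict(model: nn.Module, device: str = "cpu") -> dict:
    """Hypothetical helper: return an independent copy of the model's state
    dict on `device` (CPU by default), so the snapshot uses no GPU memory."""
    return {
        # copy=True guarantees new storage even if the tensor is already on `device`
        name: tensor.detach().to(device, copy=True)
        for name, tensor in model.state_dict().items()
    }


# Usage with a toy model (stand-in for self.model.model)
model = nn.Linear(2, 1)
best_weights = snapshot_state_dict(model)

with torch.no_grad():
    model.weight.zero_()  # simulate training changing the weights in place

# load_state_dict copies the values back into the existing parameters,
# on whatever device the model lives on
model.load_state_dict(best_weights)
```

The trade-off is a host-to-device copy each time the best weights are loaded, in exchange for not holding a second full set of weights on the GPU.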

Inference:

def inference_single_batch(self, image_batch):
    """
    Inference for a small number of images, should fit completely in (GPU) memory
    """
    # Save the current training weights (to avoid an infinite loop if predictions are requested at short intervals)
    self.train_state_dict = copy.deepcopy(self.model.model.state_dict())

    # Load the best model weights if they exist
    if self.best_weights is not None:
        self.model.model.load_state_dict(self.best_weights)

    # Offload training data if they are on the GPU
    self.training_data_cuda_to_cpu()

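    # Set the model to evaluation mode if necessary, remembering the previous mode
    was_training = self.model.model.training
    if was_training:
        self.model.model.eval()

    # Perform inference
    image_batch = {Data.IMAGE: torch.from_numpy(np.concatenate(image_batch)).to(self.device)}
    results, _ = self.inference.batch(image_batch)

    # Reload the training weights and switch back to training mode if the model was training before
    self.model.model.load_state_dict(self.train_state_dict)
    if was_training:
        self.model.model.train()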
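The save/load/restore steps in `inference_single_batch` could also be packaged as a context manager, so the training weights and the train/eval mode are restored even if inference raises an exception. A minimal sketch; `temporarily_load` is a hypothetical name and the toy `nn.Linear` stands in for the real model:

```python
import contextlib
import copy

import torch
import torch.nn as nn


@contextlib.contextmanager
def temporarily_load(model: nn.Module, state_dict):
    """Hypothetical helper: use `state_dict` (e.g. the best weights) inside the
    with-block, then restore the training weights and the previous mode."""
    train_state = copy.deepcopy(model.state_dict())  # snapshot of the training weights
    was_training = model.training
    if state_dict is not None:
        model.load_state_dict(state_dict)
    model.eval()
    try:
        with torch.no_grad():  # no autograd graph needed for inference
            yield model
    finally:
        # Runs even if inference fails, so training resumes where it left off
        model.load_state_dict(train_state)
        model.train(was_training)


# Usage with a toy model
model = nn.Linear(2, 1)
best_weights = copy.deepcopy(model.state_dict())
with torch.no_grad():
    model.weight.add_(1.0)  # pretend training has moved the weights on
current = model.weight.detach().clone()

x = torch.randn(4, 2)
with temporarily_load(model, best_weights) as m:
    preds = m(x)  # computed with the best weights
```

After the with-block, `model` holds the training weights again and is back in training mode.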

I hope someone can help me,

Regards,
Frank

This shouldn’t be the case; deepcopy should work, as seen here:

import copy

import torch
import torch.nn as nn

best_loss = 1000000.
model = nn.Linear(1, 1)
best_state = copy.deepcopy(model.state_dict())

optimizer = torch.optim.Adam(model.parameters(), lr=1.)
criterion = nn.MSELoss()

x = torch.randn(1, 1)
target = torch.randn_like(x)

for epoch in range(10):
    out = model(x)
    loss = criterion(out, target)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Note: optimizer.step() has already updated the weights, so the copy below
    # holds the weights after the update that produced `loss`
    if loss.item() < best_loss:
        print("saving new state_dict")
        print("best_state before\n{}".format(best_state))
        best_loss = loss.item()
        best_state = copy.deepcopy(model.state_dict())
        print('best_state now\n{}'.format(best_state))

I keep asking stupid late-night questions, sorry about that.

Thanks for the confirmation that it should work. By the way, I moved the saving of the model before the loss.backward() and optimizer calls, since I thought that is where the weights get updated?
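On the question of where the weights change: `loss.backward()` only fills the `.grad` fields, and the weights themselves change in `optimizer.step()`. A loss computed in the forward pass therefore belongs to the weights before the step, so taking the snapshot before `step()` pairs each loss with the weights that produced it. A minimal check, using a toy `nn.Linear` and SGD as stand-ins:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, target = torch.randn(8, 1), torch.randn(8, 1)

before = model.weight.detach().clone()

loss = nn.functional.mse_loss(model(x), target)
loss.backward()  # computes gradients into model.weight.grad; weights untouched
after_backward = model.weight.detach().clone()

optimizer.step()  # this call is what updates the weights
after_step = model.weight.detach().clone()

print(torch.equal(before, after_backward), torch.equal(before, after_step))
```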

Anyway, the real problem was that I did some augmentation live on the GPU (flipping and so on), and when updating the batch I updated the targets but not the data (images) themselves. So after “live” annotations (updating the batch data), my results kept getting worse. Because the model running inference with the “best weights” was giving worse results, I somehow concluded that the weights had to be wrong. I underestimated how fast the model overfits and how quickly a worse model can reach a lower loss.
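The mismatch described above (targets updated but images not) is avoided if one random decision per sample drives both the image and its target. A minimal sketch for horizontal flips on GPU tensors, assuming image batches of shape (N, C, H, W) and per-pixel target masks of shape (N, H, W); `random_hflip` is a hypothetical helper, and box or keypoint targets would need their coordinates transformed instead:

```python
import torch


def random_hflip(images: torch.Tensor, targets: torch.Tensor, p: float = 0.5):
    """Hypothetical helper: flip each sample horizontally with probability p,
    applying the same decision to the image and to its target mask.

    Assumes images are (N, C, H, W) and targets are (N, H, W).
    """
    # One random decision per sample, created on the same device as the data
    flip = torch.rand(images.shape[0], device=images.device) < p
    images = images.clone()
    targets = targets.clone()
    # The last dimension is the width for both layouts
    images[flip] = torch.flip(images[flip], dims=[-1])
    targets[flip] = torch.flip(targets[flip], dims=[-1])
    return images, targets
```

Because the same `flip` mask indexes both tensors, an image and its target can never end up out of sync.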

Long story short: I’m not sure whether you would like to keep my question here, so feel free to delete it. Thanks for the answer (and for all the other indirect help I get from you when googling or searching this forum)!