Dear readers,
For reasons, I would like to train a model (on the GPU) and keep track of “the best” model weights (those with the lowest loss so far). Sometimes, with the same model, I would like to switch to inference and load “the best” weights. When I then start training again, I would like to continue where I left off and not from the “best” weights.
So my approach is the following:
- Save a deepcopy of the state dict every time a new lowest loss is reached
- Save the (train) state dict every time I switch to inference
- Load the lowest-loss state dict when doing inference
- Load the train state dict when inference is finished
It does not seem to work: the best-weights state dict keeps getting updated. Is deepcopy not the right approach? I would like to keep everything in memory rather than saving it to a file, for reasons.
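To make the four steps concrete, here is a minimal standalone sketch of the intended behaviour. The toy `nn.Linear` model, the random data, and the helper `train_step` are hypothetical stand-ins, not the real setup. In this isolated form, the deepcopied snapshot should stay fixed while training continues.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model and data, standing in for the real model and batch.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, targets = torch.randn(16, 4), torch.randn(16, 1)


def train_step():
    """Run one optimisation step and return the loss computed before the update."""
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()


# Step 1: snapshot the weights whenever the loss reaches a new minimum.
# As in the trainer, the snapshot is taken after optimizer.step().
best_loss, best_weights = float("inf"), None
for _ in range(5):
    loss = train_step()
    if loss < best_loss:
        best_loss = loss
        best_weights = copy.deepcopy(model.state_dict())

# Extra copy, used only to detect whether best_weights changes later on.
reference = copy.deepcopy(best_weights)

# Steps 2 and 3: save the train weights, then load the best weights for inference.
train_weights = copy.deepcopy(model.state_dict())
model.load_state_dict(best_weights)
model.eval()
with torch.no_grad():
    predictions = model(inputs)

# Step 4: restore the train weights and continue training.
model.load_state_dict(train_weights)
model.train()
for _ in range(5):
    train_step()

# The snapshot should not have followed the live parameters.
snapshot_unchanged = all(
    torch.equal(best_weights[k], reference[k]) for k in best_weights
)
snapshot_differs_from_live = any(
    not torch.equal(best_weights[k], v) for k, v in model.state_dict().items()
)
```

If `snapshot_unchanged` comes out `False` in a setup like this, the snapshot is being overwritten somewhere; if it is `True`, the cause is more likely in the condition that decides when a new snapshot is taken.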
Here is the relevant code:
Trainer:
loss, model_metrics = self.model(self.batch)
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
self.total_number_of_epochs += 1
stopping_criterium = model_metrics["loss"]
if stopping_criterium < self.best_stopping_criterium:
    self.best_stopping_criterium = stopping_criterium
    self.best_loss = model_metrics['loss']
    # Save state dict to disk and to memory
    self.model.save(self.model_weight_path)
    self.best_weights = copy.deepcopy(self.model.model.state_dict())
Inference:
def inference_single_batch(self, image_batch):
    """
    Inference for a small number of images; should fit completely in (GPU) memory.
    """
    # Save the current weights (to avoid an infinite loop if predictions are requested at short intervals)
    self.train_state_dict = copy.deepcopy(self.model.model.state_dict())
    # Load the best model weights if they exist
    if self.best_weights is not None:
        self.model.model.load_state_dict(self.best_weights)
    # Offload the training data if it is on the GPU
    self.training_data_cuda_to_cpu()
    # Set the model to evaluation mode if necessary
    if self.model.model.training:
        self.model.model.eval()
    # Perform inference
    image_batch = {Data.IMAGE: torch.from_numpy(np.concatenate(image_batch)).to(self.device)}
    results, _ = self.inference.batch(image_batch)
    # Reload the train weights
    self.model.model.load_state_dict(self.train_state_dict)
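The swap-and-restore pattern in the inference method could also be wrapped in a context manager, as in the sketch below. `temporary_weights` is a hypothetical helper (it is not part of PyTorch), and the toy `nn.Linear` model stands in for `self.model.model`. Unlike the method above, this sketch also restores the previous train/eval mode, which is an assumption about what is wanted.

```python
import copy
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def temporary_weights(module, state_dict):
    """Load `state_dict` into `module` in eval mode, then restore the previous state.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    saved_weights = copy.deepcopy(module.state_dict())
    was_training = module.training
    if state_dict is not None:
        module.load_state_dict(state_dict)
    module.eval()
    try:
        yield module
    finally:
        # Runs even if inference raises, so training resumes where it left off.
        module.load_state_dict(saved_weights)
        module.train(was_training)


# Usage with a toy model standing in for self.model.model.
torch.manual_seed(0)
model = nn.Linear(4, 1)
best_weights = copy.deepcopy(model.state_dict())

# Pretend that training continued after the best weights were recorded.
with torch.no_grad():
    for parameter in model.parameters():
        parameter.add_(1.0)
train_weights = copy.deepcopy(model.state_dict())

with temporary_weights(model, best_weights) as eval_model, torch.no_grad():
    inside_uses_best = all(
        torch.equal(v, best_weights[k]) for k, v in eval_model.state_dict().items()
    )
    inside_is_eval = not eval_model.training
    outputs = eval_model(torch.randn(2, 4))
```

After the `with` block, the model holds the train weights again and is back in training mode, while `best_weights` itself is never modified.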
I hope someone can help me,
Regards,
Frank