About saving state_dict/checkpoint in a function

I am trying to implement the following method of my Trainer class to save model-state checkpoints:

import torch
import tqdm

def train_epoch(self):
    for epoch in tqdm.trange(self.epoch, self.max_epoch,
                             desc='Train Epoch', ncols=100):
        self.epoch = epoch  # keep the Trainer's epoch counter in sync
        checkpoint = {}  # FIXME: the dict in question; re-created every iteration
        self.train()
        if epoch % 1 == 0:  # always true: validate every epoch
            self.validate(checkpoint)
        checkpoint_latest = {
            'epoch': self.epoch,
            'arch': self.model.__class__.__name__,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
            'model_save_criteria': self.model_save_criteria,
        }
        checkpoint['checkpoint_latest'] = checkpoint_latest
        torch.save(checkpoint, self.model_pth)  # overwrites the file each epoch
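
For reference, a minimal sketch of how I would load such a checkpoint back later (load_latest is a hypothetical helper name; torch.load and load_state_dict are standard PyTorch):

import torch

def load_latest(model, optim, model_pth):
    # Restore the training state written by train_epoch() above.
    checkpoint = torch.load(model_pth)
    latest = checkpoint['checkpoint_latest']
    model.load_state_dict(latest['model_state_dict'])
    optim.load_state_dict(latest['optimizer_state_dict'])
    # hand back the bookkeeping values so training can resume
    return latest['epoch'], latest['model_save_criteria']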

Previously I did the same thing with a plain for loop:

import time
import torch

train_states = {}
for epoch in range(max_epochs):
    running_loss = 0
    time_batch_start = time.time()
    model.train()
    for bIdx, sample in enumerate(train_loader):
        ...  # per-batch training step (elided)
    # ... validation (elided) ...
    train_states_latest = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'model_save_criteria': chosen_criteria,
    }
    train_states['train_states_latest'] = train_states_latest
    torch.save(train_states, FILEPATH_MODEL_SAVE)

Is there a way to initialize checkpoint = {} once and then update it on every loop iteration? Or is re-creating checkpoint = {} in every epoch fine, given that the model itself holds the state_dict() and I am simply overwriting the saved checkpoint file each time?
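
For concreteness, here is a minimal runnable sketch of the two variants on a toy model (the Linear model, SGD optimizer, and TMP_PATH below are made-up placeholders, not my real setup):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                # toy stand-in for self.model
optim = torch.optim.SGD(model.parameters(), lr=0.1)
TMP_PATH = 'checkpoint.pth'            # hypothetical save path

# Variant A: a fresh dict every epoch (what train_epoch() above does);
# the previous dict is simply garbage-collected.
for epoch in range(3):
    checkpoint = {}
    checkpoint['checkpoint_latest'] = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
    }
    torch.save(checkpoint, TMP_PATH)   # file is overwritten either way

# Variant B: one dict created before the loop, mutated in place.
checkpoint = {}
for epoch in range(3):
    checkpoint['checkpoint_latest'] = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
    }
    torch.save(checkpoint, TMP_PATH)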