Checkpointing model not working at evaluation time

I’m using the technique of saving and loading models as suggested in the PyTorch docs. I checkpoint the model whenever the current loss is less than the previous minimum. The problem is that when I evaluate on my validation data set, I get the same accuracy whether or not I load the checkpoint into the model. Ideally the accuracy should differ, right?

Here are my code snippets:

def evaluate(dataloader, model, checkpoint=None):
  correct = 0
  tot_box = 0
  if checkpoint:
    model.load_state_dict(checkpoint['state'])
  model.eval()
  with torch.no_grad():
    for X, Y in dataloader:
      X, Y = X.to(device), Y.to(device)
      outputs = model(X)
      for batch in range(Y.shape[0]):
        boxes = decode_to_boxes(Y[batch], outputs[batch], grid_size, input_img_size[0])
        tot_box += boxes.shape[0]
        for box in boxes:
          iou_score = iou(box[0], box[1])
          if iou_score >= 0.8:
            correct += 1

  accuracy = correct * 100 / tot_box
  return accuracy

checkpoint = {}
def fit(epochs=10, lr=0.0001):
  opt = optim.Adam(model.parameters(), lr)
  min_loss = 10
  for epoch in tqdm(range(epochs), total=epochs, unit="epoch"):
    for X_train, Y_train in train_loader:
      X_train, Y_train = X_train.to(device), Y_train.to(device)
      outputs = model(X_train)
      loss = loss_fn(outputs, Y_train)

      opt.zero_grad()
      loss.backward()
      opt.step()

      if loss.item() < min_loss:
        min_loss = loss.item()
        checkpoint['loss'] = (min_loss, epoch)
        checkpoint['state'] = model.state_dict()

    print('Epoch {}/{} loss {} min_loss {}'.format(epoch+1, epochs, round(loss.item(), 2), round(min_loss, 2)))

model = Network().to(device)
fit(40)
print(checkpoint['loss'])
-> (0.16597580909729004, 33)
evaluate(val_loader, model)
-> 45.06578947368421
evaluate(val_loader, model, checkpoint)
-> 45.06578947368421

I’m not able to understand why the accuracy comes out the same. I hope I’ve conveyed my question clearly.

Hi,

The problem is that when you get the state dict, it does not clone the tensors but gives you references to them.
So when further training updates your model, it also updates the values stored in your checkpoint.
You can do: checkpoint['state'] = {name: t.clone() for name, t in model.state_dict().items()}
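
To see this concretely, here is a minimal sketch (using a throwaway nn.Linear as a stand-in for your Network) showing that the tensors returned by state_dict() keep tracking later updates, while cloned copies do not:

import torch
import torch.nn as nn

layer = nn.Linear(2, 2)  # stand-in for the real model

aliased = layer.state_dict()  # shares storage with the live parameters
cloned = {name: t.clone() for name, t in layer.state_dict().items()}  # independent copies

with torch.no_grad():
  layer.weight.add_(1.0)  # simulate a later training update

print(torch.equal(aliased['weight'], layer.weight))  # True  -> the stored dict changed too
print(torch.equal(cloned['weight'], layer.weight))   # False -> the clone kept the old values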

Hey thanks, it worked.
By the way, checkpoint['state'] = model.state_dict() used to work in my previous projects. In fact, PyTorch also suggests the same code in its docs for checkpointing, but it isn’t working anymore.
Is checkpoint['state'] = {name: t.clone() for name, t in model.state_dict().items()} the new way of checkpointing?

It’s a bit hard to tell, as the plain assignment would work if you stored the checkpoint (e.g. via torch.save) directly afterwards.
However, since you are temporarily keeping the state_dict in memory and continuing the training, the checkpoint gets updated as well, as explained by @albanD.

I don’t think this behavior has changed, but in which version was your approach working?
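
For reference, both of these patterns avoid the aliasing problem; here is a minimal sketch, with a stand-in nn.Linear and a hypothetical best.pt filename:

import copy
import torch
import torch.nn as nn

net = nn.Linear(2, 2)  # stand-in for your Network

# Option 1: write the best checkpoint to disk as soon as it is found.
# torch.save serializes the current values, so later training cannot change the file.
torch.save({'state': net.state_dict()}, 'best.pt')
best = torch.load('best.pt')
net.load_state_dict(best['state'])

# Option 2: keep the best weights in memory, but deep-copy the state_dict so it is
# decoupled from the live parameters (equivalent to the clone() comprehension above).
best_in_memory = copy.deepcopy(net.state_dict())

Saving to disk also has the benefit that the best weights survive a crash or notebook restart.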