Checkpointing a DataParallel model

Does it make a difference whether you checkpoint your model for retraining after the model.eval() (validation) loop or after the model.train() loop?

It shouldn’t make any difference, as long as you don’t update the parameters in your validation loop.
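A minimal sketch of that idea, assuming a model, optimizer, val_loader, epoch counter, and checkpoint filename that are not from the original post: because the validation loop never calls optimizer.step(), saving right after it captures the same parameters as saving right after training.

import torch

model.eval()                       # switch to eval mode for validation
with torch.no_grad():              # no gradients, so parameters stay untouched
    for inputs, labels in val_loader:
        outputs = model(inputs)
        # ... compute validation metrics ...

# Saving here (or right after the training loop) stores identical parameters,
# since nothing between the two points updates them.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')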

This question seems to be unrelated to the topic title, so are you having any issues using DataParallel?

I am trying to save a DataParallel model but am getting:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-17-e0020f404c69> in <module>()
     69             optimizer.zero_grad() # Zero gradients
     70             loss.backward() # Calculate gradients
---> 71             optimizer.step() # Update weights
     72             m.track_loss(loss)
     73             m.track_num_correct(preds, labels)

~/pytorch-1.0-p3/anaconda3/lib/python3.6/site-packages/torch/optim/adamw.py in step(self, closure)
     98 
     99                 # Decay the first and second moment running average coefficient
--> 100                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
    101                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
    102                 if amsgrad:

RuntimeError: expected device cpu but got device cuda:0

The stack trace points to optimizer.step(), which is unrelated to saving the state_dict.

How did you pass the parameters to the optimizer?
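For context (this ordering is not from the original post, just the usual pattern): the model is typically moved to the GPU and wrapped in DataParallel before the optimizer is constructed, so that the optimizer holds references to the CUDA parameters and its state tensors end up on the same device as the gradients. The lr value below is a placeholder.

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

network = network.to(device)            # move parameters to the GPU first
if torch.cuda.device_count() > 1:
    network = nn.DataParallel(network)  # then wrap for multi-GPU

# The optimizer now references the CUDA parameters, so its running-average
# buffers (exp_avg, exp_avg_sq) will be created on the same device.
optimizer = optim.AdamW(network.parameters(), lr=1e-3)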

So this worked for me, although it is a little weird as a workflow:

import torch
import torch.nn as nn
import torch.optim as optim

# While saving the checkpoint, wrap in DataParallel here
# (comment this block out while loading the checkpoint):
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    network = nn.DataParallel(network)

network.to(device)
optimizer = optim.AdamW(network.parameters(), lr=run.lr, weight_decay=run.weight_decay)

with open('check-point.pth', 'rb') as f:
    print('file opened')
    checkpoint = torch.load(f)
    print('file loaded')
    network.load_state_dict(checkpoint['model_state_dict'])
    print('network loaded')
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print('optimizer loaded')
    epoch = checkpoint['epoch']
    print(f'epoch: {epoch}')

# While loading the checkpoint, wrap in DataParallel only after
# the state dicts have been restored:
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    network = nn.DataParallel(network)
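A common way to avoid depending on when the model gets wrapped (not from the original post, just a widely used convention) is to save the underlying module's state_dict, so the checkpoint keys carry no 'module.' prefix and the file loads the same way whether or not DataParallel is in use. A sketch, assuming the same network, optimizer, run, device, and epoch names as above:

# Saving: strip the DataParallel wrapper so keys have no 'module.' prefix.
model_to_save = network.module if isinstance(network, nn.DataParallel) else network
torch.save({
    'epoch': epoch,
    'model_state_dict': model_to_save.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'check-point.pth')

# Loading: restore into the plain model, move it to the target device,
# build the optimizer from the CUDA parameters, then wrap for multi-GPU.
checkpoint = torch.load('check-point.pth', map_location=device)
network.load_state_dict(checkpoint['model_state_dict'])
network.to(device)
optimizer = optim.AdamW(network.parameters(), lr=run.lr, weight_decay=run.weight_decay)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if torch.cuda.device_count() > 1:
    network = nn.DataParallel(network)

Since the checkpoint never contains the 'module.' prefix, the same file can be resumed on a single GPU, multiple GPUs, or the CPU without editing the keys.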