Does it make a difference whether you checkpoint your model for retraining after the model.eval() loop or after the model.train() loop?
It shouldn’t make any difference, as long as you don’t update the parameters in your validation loop.
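To make that concrete, here is a minimal sketch (the toy model and the random validation batches are assumptions for illustration, not code from this thread). It snapshots the state_dict after training mode, runs a validation loop under model.eval() and torch.no_grad() without any optimizer step, and checks that nothing changed:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model; BatchNorm is included because its running
# statistics are the buffers most likely to differ between modes.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))

# Snapshot taken right after the (omitted) training loop.
model.train()
after_train = {k: v.clone() for k, v in model.state_dict().items()}

# Validation loop: eval mode, no gradient tracking, no optimizer.step().
model.eval()
with torch.no_grad():
    for _ in range(3):
        model(torch.randn(16, 4))

after_eval = model.state_dict()

# True if every parameter and buffer is identical in both snapshots.
unchanged = all(torch.equal(after_train[k], after_eval[k]) for k in after_train)
print(unchanged)
```

Note that model.train() / model.eval() is not stored in the state_dict, so after loading a checkpoint you still have to set the mode you want yourself.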
This question seems to be unrelated to the topic, so do you have any issues using DataParallel?
I am trying to save a DataParallel model but am getting this error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-17-e0020f404c69> in <module>()
69 optimizer.zero_grad() # Zero gradients
70 loss.backward() # Calculate gradients
---> 71 optimizer.step() # Update weights
72 m.track_loss(loss)
73 m.track_num_correct(preds, labels)
~/pytorch-1.0-p3/anaconda3/lib/python3.6/site-packages/torch/optim/adamw.py in step(self, closure)
98
99 # Decay the first and second moment running average coefficient
--> 100 exp_avg.mul_(beta1).add_(1 - beta1, grad)
101 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
102 if amsgrad:
RuntimeError: expected device cpu but got device cuda:0
The stack trace points to optimizer.step(), which is unrelated to saving the state_dict.
How did you pass the parameters to the optimizer?
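One common cause of this kind of device mismatch (an assumption here, since the thread does not show the failing setup) is that the optimizer state was loaded while the parameters were still on the CPU, and the model was moved to the GPU afterwards. A minimal sketch of an ordering that avoids it: move the model to the target device first, then create the optimizer, then load both state dicts. The `build` helper, the toy `nn.Linear` model, and the in-memory buffer are made up for illustration.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build(device):
    """Hypothetical helper: model goes to `device` BEFORE the optimizer is created."""
    network = nn.Linear(4, 2).to(device)
    optimizer = optim.AdamW(network.parameters(), lr=1e-3)
    return network, optimizer


# Train for one step so the optimizer has state (exp_avg, exp_avg_sq).
network, optimizer = build(device)
network(torch.randn(8, 4, device=device)).sum().backward()
optimizer.step()

# Save to an in-memory buffer instead of a file to keep the sketch self-contained.
buffer = io.BytesIO()
torch.save(
    {
        "model_state_dict": network.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    buffer,
)
buffer.seek(0)

# Resume: the parameters already live on `device` when the optimizer state is
# loaded, so the loaded moment buffers are placed on the same device.
checkpoint = torch.load(buffer, map_location="cpu")
network, optimizer = build(device)
network.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

exp_avg_devices = {state["exp_avg"].device.type for state in optimizer.state.values()}

# The next update runs without a device mismatch.
optimizer.zero_grad()
network(torch.randn(8, 4, device=device)).sum().backward()
optimizer.step()
```

If the model has to be moved after the optimizer state was loaded, the state tensors would need to be moved to the new device as well before calling optimizer.step().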
So this worked for me, although it is a slightly weird workflow:
# Use this block when saving a checkpoint, i.e. comment it out when loading a checkpoint
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    network = nn.DataParallel(network)
network.to(device)
optimizer = optim.AdamW(network.parameters(), lr=run.lr, weight_decay=run.weight_decay)
# try:
with open('check-point.pth', 'rb') as f:
    print('file opened')
    checkpoint = torch.load(f)
    print('file loaded')
network.load_state_dict(checkpoint['model_state_dict'])
print('network loaded')
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print('optimizer loaded')
epoc = checkpoint['epoch']
print(f'blah epoc: {epoc}')
# Use this block when loading a checkpoint: wrap the model only after the state dicts are loaded
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    network = nn.DataParallel(network)
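A possible way to avoid commenting blocks in and out is to always save the state_dict of the underlying module, so the checkpoint has no "module." key prefix and loads the same way with or without DataParallel. This is only a sketch under assumptions: the `unwrap` helper, the toy `nn.Linear` model, and the in-memory buffer are hypothetical and not part of the workflow above.

```python
import io

import torch
import torch.nn as nn


def unwrap(model):
    """Hypothetical helper: return the plain module behind an nn.DataParallel wrapper."""
    return model.module if isinstance(model, nn.DataParallel) else model


network = nn.DataParallel(nn.Linear(4, 2))

# The wrapper prefixes every key with "module.", the inner module does not.
wrapped_keys = sorted(network.state_dict().keys())
plain_keys = sorted(unwrap(network).state_dict().keys())

# Save the unwrapped weights (in-memory buffer instead of a file for the sketch).
buffer = io.BytesIO()
torch.save({"model_state_dict": unwrap(network).state_dict()}, buffer)
buffer.seek(0)

# Load into a plain model first, then wrap it if several GPUs are available.
checkpoint = torch.load(buffer, map_location="cpu")
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["model_state_dict"])
if torch.cuda.device_count() > 1:
    restored = nn.DataParallel(restored)

# Compare on the CPU so the check works regardless of where the models live.
same_weights = torch.equal(
    unwrap(restored).weight.detach().cpu(),
    unwrap(network).weight.detach().cpu(),
)
```

With this layout the loading code stays identical on one GPU, several GPUs, or the CPU; only the final wrapping step depends on the device count.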