Dear community,
The problem is that the average epoch training loss of my network converges nicely (say to .005 mean SmoothL1Loss), but after I save a checkpoint and later load from it, the average epoch training loss is back at .05 (ten times worse). The same thing happens when loading from the next checkpoint: the model never picks up where it left off.
The input data has been checked and is all in order; otherwise the model wouldn’t converge this well in the first place.
Because of all the preprocessing steps, the average epoch loss should be highly consistent, and it is within a run, just not across checkpoints.
Switching to another optimizer gives exactly the same problem. I’ve tried saving and loading in many different ways, exhausting all the online resources I could find, but to no avail.
I would call it an obscure PyTorch checkpoint problem.
What I use to save:
# Save FusionNet checkpoint
torch.save({
    'model_module_state_dict': FusionNet.module.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict()
}, f'{models_folder}FusionNet_snapshot{load_snapshot+iE+1}.tar')
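For what it’s worth, a minimal round-trip sanity check (standalone toy model, not my actual network; all names here are made up) confirms that torch.save/torch.load preserves a state_dict bit-exactly:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Tiny stand-in model, mirroring the save/load pattern above
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Save a checkpoint the same way as above
path = os.path.join(tempfile.mkdtemp(), 'snapshot.tar')
torch.save({
    'model_module_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, path)

# Reload into a freshly constructed model
model2 = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
checkpoint = torch.load(path, map_location='cpu')
model2.load_state_dict(checkpoint['model_module_state_dict'])

# Every parameter and buffer tensor round-trips bit-exactly
for k, v in model.state_dict().items():
    assert torch.equal(v, model2.state_dict()[k]), k
print('state_dict round-trip OK')
```

So the serialization itself doesn’t seem to be losing precision; something else must change between runs.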
What I use to load:
# Initiate FusionNet
FusionNet = FusionGenerator(1, 1, 64)
if load_snapshot:
    model_path = f'{models_folder}FusionNet_snapshot{load_snapshot}.tar'
    checkpoint = torch.load(model_path, map_location=nn_handler_device)
    check = FusionNet.load_state_dict(checkpoint['model_module_state_dict'])
FusionNet = nn.DataParallel(
    FusionNet.to(device=nn_handler_device),
    device_ids=gpu_device_ids
)
# Define optimizer and send to GPU
optimizer = torch.optim.Adam(FusionNet.parameters(), lr=lr, weight_decay=weight_decay)
if load_snapshot:
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    optimizer_to(optimizer, torch.device(nn_handler_device))
# Define scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

# Optional model snapshot loading
if load_snapshot:
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    print(f'\tSnapshot of model {model} at epoch {load_snapshot} restored...')
print(f'\tUsing network to train on images from {data_set}/{data_subset}...')
# Make sure training mode is enabled
FusionNet.train()
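To localize where the jump comes from, I think the cleanest diagnostic would be to evaluate the loss on one frozen batch immediately before saving and immediately after loading: if the two numbers match, the weights round-trip faithfully and the discrepancy must come from elsewhere (data order, augmentation state, train/eval mode, etc.). A sketch of that check, again with hypothetical stand-ins rather than my real network and data:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins; the point is the fixed-batch comparison,
# not the model itself
torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = nn.SmoothL1Loss()
x, y = torch.randn(16, 4), torch.randn(16, 1)

def fixed_batch_loss(net):
    # Always evaluate on the same frozen batch, in eval mode, so
    # dropout/batchnorm cannot change the number between calls
    net.eval()
    with torch.no_grad():
        loss = criterion(net(x), y).item()
    net.train()
    return loss

loss_before = fixed_batch_loss(model)   # just before torch.save(...)

# ... save, tear everything down, rebuild, load ...
state = model.state_dict()
model2 = nn.Linear(4, 1)
model2.load_state_dict(state)

loss_after = fixed_batch_loss(model2)   # just after load_state_dict(...)
assert abs(loss_before - loss_after) < 1e-7  # weights round-tripped faithfully
```

In my toy tests the two losses agree, which makes the epoch-level jump in my real training loop all the more puzzling.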