Should I save optimizer state_dict?

Hi guys, I’ve been doing object detection with RetinaNet. I wanted to see whether my network is learning well, so I intentionally ran predictions on the training set. However, the results turned out to be terrible: my output contains way too many boxes, in fact they cover the whole image. I suspect that the way I save or load the state_dict is wrong…
Here’s what I do…

import torch
import torch.optim as optim
from os import mkdir
from os.path import join, exists

# Retina_Net, Focal_loss, args, training_set_loader, lr_distr, mom_distr,
# training_loss and best_loss are defined elsewhere in my script
training_iterations = 0
avg_loss = 0.
model = Retina_Net(args.num_class)
# load resnet50 weights and some initializations
weights = torch.load(join("weights", "retinet.pth"))
model.load_state_dict(weights)
model.cuda()
loss_function = Focal_loss(args.num_class)
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

def train(epoch, mode=True, beta=0.99):
    model.train(mode)
    global training_iterations
    global avg_loss
    for num_batch, (inputs, loc_targets, cls_targets) in enumerate(training_set_loader):
        inputs = inputs.cuda()
        loc_targets = loc_targets.cuda()
        cls_targets = cls_targets.cuda()
        optimizer.zero_grad()
        # I use cosine warmup and annealing for the learning rate and momentum
        optimizer.param_groups[0]["lr"] = lr_distr[training_iterations]
        optimizer.param_groups[0]["betas"] = (mom_distr[training_iterations].item(), 0.999)
        loc_preds, cls_preds = model(inputs)
        loss = loss_function(loc_preds, loc_targets, cls_preds, cls_targets)
        # backprop
        loss.backward()
        # update parameters
        optimizer.step()
        # bias-corrected exponential moving average of the training loss
        avg_loss = loss.item() * (1 - beta) + avg_loss * beta
        smooth_loss = avg_loss / (1 - beta ** (training_iterations + 1))
        training_iterations += 1
        training_loss.append(smooth_loss)

    global best_loss
    model.eval()
    if smooth_loss < best_loss:
        weight_dir = join("weights", args.files.split(".")[0])
        if not exists(weight_dir):
            mkdir(weight_dir)
        torch.save(model.state_dict(),
                   join(weight_dir, "Epoch_{}_{:.3f}.pth".format(epoch, smooth_loss)))
        best_loss = smooth_loss

for epoch in range(1, args.epochs + 1):
    train(epoch)
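Since my title asks about the optimizer: right now I only save model.state_dict(). As far as I understand, that is enough for inference, and I would only need to checkpoint the optimizer as well if I wanted to resume training later. A rough sketch of what I assume that would look like (weight_dir and epoch as in the train function above):

# sketch: checkpoint model + optimizer + bookkeeping in one file,
# so training could be resumed later (not needed for inference only)
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "training_iterations": training_iterations,
    "best_loss": best_loss,
}
torch.save(checkpoint, join(weight_dir, "checkpoint_{}.pth".format(epoch)))

# resuming later:
checkpoint = torch.load(join(weight_dir, "checkpoint_{}.pth".format(epoch)))
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])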

And predicting:

model = Retina_Net()
losses = []
# weights is the list of saved .pth paths; sort by the loss encoded
# in the filename and use the checkpoint with the lowest loss
for weight in weights:
    losses.append((weight, float(weight.split("/")[-1].split("_")[-1].split(".pth")[0])))
losses.sort(key=lambda x: x[1])

model.load_state_dict(torch.load(losses[0][0]))

# turn the model into test mode
model.eval()

# then predict

I would like to know if there’s anything I’m doing wrong or missing.
Thanks in advance!

Could you explain your code a bit, please?
How did you store the state_dict and how are you processing weights?

I forgot to mention how I saved the weights; I’ve edited that into the train function by adding:

torch.save(model.state_dict(),
           join(weight_dir, "Epoch_{}_{:.3f}.pth".format(epoch, smooth_loss)))

So it looks like:

    global best_loss
    model.eval()
    if smooth_loss < best_loss:
        weight_dir = join("weights", args.files.split(".")[0])
        if not exists(weight_dir):
            mkdir(weight_dir)
        torch.save(model.state_dict(),
                   join(weight_dir, "Epoch_{}_{:.3f}.pth".format(epoch, smooth_loss)))
        best_loss = smooth_loss

Basically I save the weights every epoch, because I want to run predictions on the training set to check whether the model is learning well. I switch the model into eval mode and then run the prediction. I don’t process the weights in any way; what I do is simply:

model = Retina_Net()
losses = []
# weights is the list of saved .pth paths; sort by the loss encoded
# in the filename and use the checkpoint with the lowest loss
for weight in weights:
    losses.append((weight, float(weight.split("/")[-1].split("_")[-1].split(".pth")[0])))
losses.sort(key=lambda x: x[1])

model.load_state_dict(torch.load(losses[0][0]))

# turn the model into test mode
model.eval()

# then predict
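
One more thing I could check, I guess: load_state_dict is strict by default, so I assume a key mismatch between the file and the model would raise an error here rather than load silently. A sketch of how the keys could be inspected explicitly (same names as above):

state = torch.load(losses[0][0], map_location="cpu")
# with strict=False, load_state_dict returns the mismatched keys
result = model.load_state_dict(state, strict=False)
# both lists should be empty if the checkpoint matches the model
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)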

If I missed any information, please let me know. Thanks!

Thanks for the update.
I cannot find any obvious error, but I would recommend trying to narrow down the issue by storing just one state_dict after e.g. 10 epochs, restoring it, and checking the result.
If you are able to save and load a single state_dict and get approximately the same model performance, the error might be in your current storing mechanism.
Otherwise we would need to look into the model or training routine.
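
Something along these lines, as a rough sketch (fixed_inputs stands for any fixed batch from your loader, already on the GPU; the filename is just an example):

# save/load round trip on one fixed batch
model.eval()
with torch.no_grad():
    ref_loc, ref_cls = model(fixed_inputs)

torch.save(model.state_dict(), "debug.pth")

model2 = Retina_Net(args.num_class).cuda()
model2.load_state_dict(torch.load("debug.pth"))
model2.eval()
with torch.no_grad():
    loc_preds, cls_preds = model2(fixed_inputs)

# should print True, True if saving and loading work correctly
print(torch.allclose(ref_loc, loc_preds), torch.allclose(ref_cls, cls_preds))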


Okay, thanks for your suggestion!