I have trained a ResNet-34 on my own dataset, which has 53 classes. Every so often new classes are added, and I need to retrain the model to include them. For example, this week 9 new classes came out, so I have to retrain my model with those 9 additional classes.
My approach is to load the current model and its state_dict, then replace the fully connected layer with one that outputs 62 classes (in this case). Here is my code for resuming training:
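```python
import os

import torch
import torch.nn as nn

if os.path.isfile(args.resume):
    print(f"Loading Checkpoint from {args.resume}")
    checkpoint = torch.load(args.resume, map_location=device)
    args.start_epoch = checkpoint['epoch']
    best_acc = checkpoint['best_acc']
    arch = args.arch
    state_dict = checkpoint['state_dict']
    optim = checkpoint['optimizer']  # optimizer state saved in the checkpoint
    if args.num_classes != checkpoint['num_classes']:
        # Rebuild the model with the old class count so the saved weights fit
        model = ItemDetectorRes(checkpoint['num_classes'], arch)
        model.load_state_dict(state_dict)
        # Swap the final fully connected layer for one with the new class count
        in_features = model.resnet.fc.in_features
        model.resnet.fc = nn.Linear(in_features, args.num_classes)
        # Treat this as a fresh run for the expanded label set
        args.start_epoch = 0
        best_acc = 0.0
    model.to(device)
```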
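A minimal sketch of the head-replacement step in isolation, assuming `ItemDetectorRes` exposes its backbone as `.resnet` with an `fc` attribute, as the code above implies. `Detector`, `TinyBackbone` and `expand_head` are hypothetical stand-ins so the example runs with only `torch`. Two details go beyond the code above and are assumptions, not part of the original approach: the learned rows for the original 53 classes are copied into the wider layer, and the optimizer is created after the swap. An optimizer constructed before `model.resnet.fc` is replaced still holds references to the old `fc` tensors, so it never updates the new layer.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for model.resnet: feature extractor followed by fc."""

    def __init__(self, num_classes: int, in_features: int = 16):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, in_features), nn.ReLU())
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


class Detector(nn.Module):
    """Hypothetical stand-in for ItemDetectorRes: exposes the backbone as .resnet."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.resnet = TinyBackbone(num_classes)

    def forward(self, x):
        return self.resnet(x)


def expand_head(model: Detector, new_num_classes: int) -> Detector:
    """Replace fc with a wider layer, keeping the weights for the old classes."""
    old_fc = model.resnet.fc
    new_fc = nn.Linear(old_fc.in_features, new_num_classes)
    with torch.no_grad():
        n_old = old_fc.out_features
        # Rows for the original classes are copied; rows for new classes keep default init
        new_fc.weight[:n_old].copy_(old_fc.weight)
        new_fc.bias[:n_old].copy_(old_fc.bias)
    model.resnet.fc = new_fc
    return model


torch.manual_seed(0)

# Pitfall: an optimizer built BEFORE the swap does not contain the new fc parameters
stale_model = Detector(num_classes=53)
stale_opt = torch.optim.SGD(stale_model.parameters(), lr=0.1)
expand_head(stale_model, 62)
stale_ids = {id(p) for group in stale_opt.param_groups for p in group["params"]}
print(id(stale_model.resnet.fc.weight) in stale_ids)  # False

# Building the optimizer AFTER the swap includes the new head
model = Detector(num_classes=53)
# (model.load_state_dict(checkpoint['state_dict']) would go here)
model = expand_head(model, 62)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One training step on random data to check that the new head actually changes
x = torch.randn(4, 8)
y = torch.randint(0, 62, (4,))
before = model.resnet.fc.weight.detach().clone()
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
changed = not torch.equal(before, model.resnet.fc.weight)
print(model(x).shape, changed)
```

Copying the old rows is optional; it only gives the 53 existing classes a warm start instead of relearning them from a random head.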
The issue is that the model does not train after this, and I cannot figure out why. Any suggestions?