Updating model with more classes

I have trained a resnet34 on my own dataset, which has 53 classes. Every so often new classes are added, and I need to retrain the model with them. For example, this week 9 new classes came out, so I have to retrain my model with the 9 additional classes.

My approach is to load the current model and its state_dict, and then replace the fully connected layer with one that outputs 62 classes (in this case). Here is my code for resuming training with this approach:

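    import os

    import torch
    import torch.nn as nn

    # `args`, `device` and `ItemDetectorRes` (a resnet-based model exposing
    # `.resnet.fc`) are defined elsewhere in the training script.
    if os.path.isfile(args.resume):
        print(f"Loading Checkpoint from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=device)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        arch = args.arch
        state_dict = checkpoint['state_dict']
        optim = checkpoint['optimizer']  # saved optimizer state_dict

        if args.num_classes != checkpoint['num_classes']:
            # Rebuild the old model (checkpoint['num_classes'] outputs) and load its weights
            model = ItemDetectorRes(checkpoint['num_classes'], arch)
            model.load_state_dict(state_dict)
            # Swap the classifier head for one with args.num_classes outputs
            in_features = model.resnet.fc.in_features
            model.resnet.fc = nn.Linear(in_features, args.num_classes)
            # Restart epoch counting and best-accuracy tracking for the new head
            args.start_epoch = 0
            best_acc = 0.0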
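The approach described above (load the old weights, then swap the head) can be sketched end to end as follows. This is a minimal, self-contained sketch: `Backbone`, `Detector` and `expand_head` are hypothetical stand-ins for the resnet34-based `ItemDetectorRes`, not the poster's actual classes. It also shows one ordering detail worth checking in any such script: an optimizer built from `model.parameters()` *before* the head is swapped still references the old `fc` parameters, so it is created here only after the swap.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    """Hypothetical stand-in for a resnet: a feature extractor followed by `fc`."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


class Detector(nn.Module):
    """Hypothetical stand-in for ItemDetectorRes, exposing `.resnet.fc`."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.resnet = Backbone(num_classes)

    def forward(self, x):
        return self.resnet(x)


def expand_head(old_state_dict, old_classes: int, new_classes: int) -> nn.Module:
    # 1. Rebuild the model with the OLD number of classes so all shapes match.
    model = Detector(old_classes)
    model.load_state_dict(old_state_dict)
    # 2. Replace only the final layer; everything else keeps its trained weights.
    in_features = model.resnet.fc.in_features
    model.resnet.fc = nn.Linear(in_features, new_classes)
    return model


# Simulated checkpoint from the 53-class run.
old_model = Detector(53)
checkpoint = {"state_dict": old_model.state_dict(), "num_classes": 53}

model = expand_head(checkpoint["state_dict"], checkpoint["num_classes"], 62)
# 3. Create the optimizer only AFTER the swap, so it holds the new fc parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=3e-3, momentum=0.9)

out = model(torch.randn(4, 32))
print(out.shape)  # torch.Size([4, 62])
```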

The issue is that the model is not training, and I cannot figure out why. Any suggestions?

Do you see a constant loss, an unchanging accuracy, or something else? What exactly do you mean by “not training”?
The code looks alright, but could you make sure that the new linear layer is being updated by printing its gradients after calling loss.backward()?
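A minimal sketch of such a check, using a hypothetical stand-in model (in the thread, the layer of interest would be `model.resnet.fc`). It prints the head's gradient after `loss.backward()` and then confirms that `optimizer.step()` actually changed the head's weights, which also catches the case where the layer has gradients but is not registered with the optimizer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in: pretrained-style feature layers plus a new 62-class head.
features = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
fc = nn.Linear(64, 62)
model = nn.Sequential(features, fc)

optimizer = torch.optim.SGD(model.parameters(), lr=3e-3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 32)
y = torch.randint(0, 62, (8,))

before = fc.weight.detach().clone()

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()

# None means the layer is not in the graph; an all-zero grad means no signal reaches it.
print("fc.weight.grad is None:", fc.weight.grad is None)
print("fc.weight.grad abs sum:", fc.weight.grad.abs().sum().item())

optimizer.step()
# Non-zero only if the optimizer holds (and updated) this layer's parameters.
print("fc.weight change:", (fc.weight.detach() - before).abs().sum().item())
```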


It does seem to be updating.

Here is an image of the training over a couple of epochs:

When I train from scratch, it trains just fine.

You might need to play around with the hyperparameters (e.g. lower the learning rate), which often helps when fine-tuning a pretrained model.

I’ll give that a shot. Still, it is strange that it does not learn at all with the higher learning rate, which is currently 3e-3.

Depending on the model, 3e-3 is not necessarily low.
What learning rate did you use when training from scratch?

The same one (3e-3), decayed by a factor of 10 every 10 epochs. I am training resnet34.
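The schedule described here (start at 3e-3, divide by 10 every 10 epochs) can be expressed with PyTorch's `StepLR` using `step_size=10` and `gamma=0.1`. This is a sketch under assumptions: the thread does not say which optimizer or scheduler class is used, so SGD with momentum and a small stand-in model are placeholders, and the dummy loss stands in for one epoch of real training.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; in the thread this would be the resnet34-based detector.
model = nn.Linear(512, 62)

optimizer = torch.optim.SGD(model.parameters(), lr=3e-3, momentum=0.9)
# Multiply the learning rate by 0.1 every 10 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

lrs = []
for epoch in range(25):
    lrs.append(optimizer.param_groups[0]["lr"])
    # Placeholder for one epoch of training on the real data.
    optimizer.zero_grad()
    loss = model(torch.randn(4, 512)).pow(2).mean()
    loss.backward()
    optimizer.step()
    # Advance the schedule once per epoch.
    scheduler.step()
```

For fine-tuning, the same schedule could simply start from a smaller base `lr`.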

If you are fine-tuning a model, I would generally recommend lowering your learning rate relative to the initial learning rate, since the majority of your parameters should already be in a good state.
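One common variant of this advice (not something the reply prescribes) is to give the already-trained layers a smaller learning rate than the freshly initialized head, using optimizer parameter groups. A minimal sketch, assuming SGD, a hypothetical stand-in model, and an arbitrary 1/10 factor for the pretrained layers:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in: pretrained feature layers plus a new 62-class head.
features = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
fc = nn.Linear(64, 62)
model = nn.Sequential(features, fc)

base_lr = 3e-3  # the from-scratch learning rate mentioned in the thread

optimizer = torch.optim.SGD(
    [
        # Already-trained layers: a smaller step size (here 1/10 of base_lr, an arbitrary choice).
        {"params": features.parameters(), "lr": base_lr * 0.1},
        # Freshly initialized head: keep the original learning rate.
        {"params": fc.parameters(), "lr": base_lr},
    ],
    lr=base_lr,
    momentum=0.9,
)

for group in optimizer.param_groups:
    print(group["lr"])
```

Lowering a single global learning rate, as the reply suggests, is the simpler option; parameter groups only add control over which layers move faster.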

Good to know. However, I think there must be an issue with my code, since even with a very low learning rate it is still not learning. I will keep debugging and post a solution when I find one.