Understanding time difference between finetuning and training the last layer with frozen weights

Hi,

I want to compare finetuning ResNet18 against training only its last layer with the other weights frozen. I'm using a dataset of The Simpsons with 20 classes and 20,000 images, with between 300 and 1000 examples per class.

Comparing accuracy, I get 91% with finetuning vs 62% with freeze and train, which looks reasonable. However, training takes 17 min with finetuning (using 4 NVIDIA M60 GPUs) and 16 min with freeze and train. Shouldn't freeze and train be much faster?

The code for finetuning is:

```python
import torch.nn as nn
from torch.optim import SGD, lr_scheduler
from torchvision import models

def finetune(dataloaders, model_name, sets, num_epochs, num_gpus, lr, momentum, lr_step, lr_epochs):
    num_class = len(dataloaders[sets[0]].dataset.class_to_idx)
    model_ft = models.__dict__[model_name](pretrained=True)
    # Replace the ImageNet classifier with one sized for our classes
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, num_class)
    if num_gpus > 1:
        model_ft = nn.DataParallel(model_ft)
    model_ft = model_ft.cuda()
    criterion = nn.CrossEntropyLoss()
    # Full finetuning: every parameter is passed to the optimizer
    optimizer = SGD(model_ft.parameters(), lr=lr, momentum=momentum)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_epochs, gamma=lr_step)
    # train_model is defined elsewhere (based on the PyTorch transfer learning tutorial)
    model_ft = train_model(dataloaders, model_ft, sets, criterion, optimizer, exp_lr_scheduler, num_epochs=num_epochs)
    return model_ft
```

The code for freeze and train is:

```python
import torch.nn as nn
from torch.optim import SGD, lr_scheduler
from torchvision import models

def freeze_and_train(dataloaders, model_name, sets, num_epochs, num_gpus, lr, momentum, lr_step, lr_epochs):
    num_class = len(dataloaders[sets[0]].dataset.class_to_idx)
    model_conv = models.__dict__[model_name](pretrained=True)
    # Freeze the pretrained backbone (params have requires_grad=True by default)
    for param in model_conv.parameters():
        param.requires_grad = False
    # The new fc layer is created after freezing, so it stays trainable
    num_ftrs = model_conv.fc.in_features
    model_conv.fc = nn.Linear(num_ftrs, num_class)
    if num_gpus > 1:
        model_conv = nn.DataParallel(model_conv)
    model_conv = model_conv.cuda()
    criterion = nn.CrossEntropyLoss()
    # Only the fc parameters are optimized; DataParallel keeps the model in .module
    if num_gpus > 1:
        params = model_conv.module.fc.parameters()
    else:
        params = model_conv.fc.parameters()
    optimizer = SGD(params, lr=lr, momentum=momentum)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_epochs, gamma=lr_step)
    model_conv = train_model(dataloaders, model_conv, sets, criterion, optimizer, exp_lr_scheduler, num_epochs=num_epochs)
    return model_conv
```
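In freeze_and_train only the new fc layer receives gradients. To get a sense of scale, the sketch below counts ResNet18's learnable parameters in pure Python, assuming torchvision's standard BasicBlock layout: bias-free convolutions, BatchNorm layers with a weight and a bias, and a 1x1 projection shortcut wherever the channel count changes. The helper names are hypothetical.

```python
# Pure-Python parameter count for ResNet18, assuming torchvision's BasicBlock layout.

def conv(c_in, c_out, k):
    """Learnable parameters of a bias-free k x k convolution."""
    return c_in * c_out * k * k

def bn(c):
    """BatchNorm learnable parameters (weight and bias); running stats are buffers."""
    return 2 * c

def linear(in_features, out_features):
    """Fully connected layer with bias."""
    return in_features * out_features + out_features

def basic_block(c_in, c_out, downsample):
    n = conv(c_in, c_out, 3) + bn(c_out) + conv(c_out, c_out, 3) + bn(c_out)
    if downsample:  # 1x1 projection shortcut when the channel count changes
        n += conv(c_in, c_out, 1) + bn(c_out)
    return n

def resnet18_backbone():
    n = conv(3, 64, 7) + bn(64)  # stem: 7x7 conv + BatchNorm
    c_in = 64
    for c_out in (64, 128, 256, 512):  # layer1..layer4, two blocks each
        n += basic_block(c_in, c_out, downsample=(c_in != c_out))
        n += basic_block(c_out, c_out, downsample=False)
        c_in = c_out
    return n

backbone = resnet18_backbone()
head = linear(512, 20)  # new fc layer for the 20 Simpsons classes
total = backbone + head
print(f"backbone: {backbone:,}  fc: {head:,}  trained when frozen: {head / total:.4%}")
```

With the original 1000-class head the same count gives 11,689,512, which matches torchvision's ResNet18. When frozen, roughly 0.09% of the weights are trained, but every forward pass still runs the full backbone, so the time saved comes mainly from skipping the backward pass through the convolutional layers.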

The train_model function is based on the PyTorch transfer learning tutorial and can be found here. The complete code is here.

Something is clearly wrong: I ran the same experiment with CNTK, and there freezing makes a big difference in training time:

Using CNTK with the Simpsons dataset:
freeze:
Test accuracy is 0.763938618926
Process time 621.0402624607086 seconds

finetune:
Test accuracy is 0.959335038363
Process time 1635.1398813724518 seconds

Your code looks fine. I’ll take a look and try to come up with a repro by tomorrow.

I've run more experiments: the issue only appears with 4 GPUs. With 1 GPU, the result is as expected.
ResNet18 on 1GPU finetune:
Training complete in 21m 37s
Best val Acc: 0.953708

ResNet18 on 1GPU freeze:
Training complete in 8m 30s
Best val Acc: 0.692327

ResNet18 on 4GPU finetune:
Training complete in 10m 11s
Best val Acc: 0.954220

ResNet18 on 4GPU freeze:
Training complete in 9m 42s
Best val Acc: 0.686957
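Turning the logged times into ratios makes the pattern explicit. This small script parses the "Xm Ys" durations reported above for ResNet18; the helper names are only for illustration.

```python
import re

# Durations copied from the "Training complete in ..." lines above (ResNet18).
TIMES = {
    ("1 GPU", "finetune"): "21m 37s",
    ("1 GPU", "freeze"): "8m 30s",
    ("4 GPU", "finetune"): "10m 11s",
    ("4 GPU", "freeze"): "9m 42s",
}

def to_seconds(duration):
    """Convert a duration such as '21m 37s' into seconds."""
    match = re.fullmatch(r"(\d+)m (\d+)s", duration)
    if match is None:
        raise ValueError(f"unexpected duration: {duration!r}")
    return int(match.group(1)) * 60 + int(match.group(2))

def speedup(slower, faster):
    """How many times faster `faster` ran than `slower` (below 1 means it was slower)."""
    return to_seconds(slower) / to_seconds(faster)

for gpus in ("1 GPU", "4 GPU"):
    ratio = speedup(TIMES[(gpus, "finetune")], TIMES[(gpus, "freeze")])
    print(f"{gpus}: freeze is {ratio:.2f}x faster than finetune")

for mode in ("finetune", "freeze"):
    ratio = speedup(TIMES[("1 GPU", mode)], TIMES[("4 GPU", mode)])
    print(f"{mode}: 4 GPUs are {ratio:.2f}x faster than 1 GPU")
```

On 1 GPU freezing is about 2.54x faster, but on 4 GPUs only about 1.05x. Finetuning gains about 2.12x from the extra GPUs, while the frozen run is actually slower on 4 GPUs than on 1 (ratio about 0.88), which points at multi-GPU overhead rather than compute.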

I'm curious whether the issue is that the dataset (20k images) is too small for 4 GPUs, or that ResNet18 is too small a network.

Update:
I trained ResNet152 on the same data:

ResNet152 on 4GPU finetune:
Training complete in 64m 55s
Best val Acc: 0.987468

ResNet152 on 4GPU freeze:
Training complete in 59m 18s
Best val Acc: 0.727110

It's increasing the number of GPUs that makes it slow, so the data I/O overhead is significant. I wonder what would happen with a bigger dataset.

I see, that makes sense. At the moment, DataParallel broadcasts parameters even when they are not modified, and it has some other overhead around replicate and broadcast_coalesced. We are in the process of improving it; see some progress here: https://github.com/pytorch/pytorch/pull/4216.
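A back-of-the-envelope model of why the gap shrinks: DataParallel replicates the whole module to every GPU on each forward pass, frozen or not, while freezing only removes backward work. Note that the code above only wraps the model in DataParallel when num_gpus > 1, so the 1-GPU runs pay no replication cost. The sketch assumes float32 parameters and the ResNet18 count with a 20-class head (11,186,772 parameters under torchvision's layout); the compute and overhead costs are hypothetical placeholder units, not measurements, and the linear overhead term is a simplifying assumption.

```python
# Toy cost model for one nn.DataParallel training iteration (hypothetical units).
PARAMS = 11_186_772   # ResNet18 backbone + 20-class fc, assumed float32
BYTES_PER_PARAM = 4

def broadcast_mb(num_gpus):
    """Parameter bytes copied from GPU 0 to the other replicas per iteration, in MB."""
    return PARAMS * BYTES_PER_PARAM * (num_gpus - 1) / 1e6

def iteration_time(num_gpus, frozen, fwd=1.0, bwd=2.0, overhead_per_gpu=0.3):
    """Compute is split across GPUs; replication overhead is not.

    fwd, bwd: hypothetical single-GPU forward/backward cost.
    overhead_per_gpu: hypothetical cost of replicating the module to one extra GPU.
    With a frozen backbone, backward only reaches the fc layer (treated as ~0 here).
    """
    compute = (fwd + (0.0 if frozen else bwd)) / num_gpus
    return compute + overhead_per_gpu * (num_gpus - 1)

for n in (1, 2, 4):
    ratio = iteration_time(n, frozen=False) / iteration_time(n, frozen=True)
    print(f"{n} GPU(s): broadcast {broadcast_mb(n):.1f} MB/iter, "
          f"finetune/freeze time ratio {ratio:.2f}")
```

With these placeholder costs the finetune/freeze ratio drops from 3.00 on one GPU to about 1.43 on four, the same trend as the measurements above. Real numbers depend on batch size, GPU interconnect, and data loading.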
