Proper way to configure a backbone for transfer learning

I am currently running transfer learning experiments in which I have a pre-trained backbone and want to train a classification model on top of it. To reduce the computational footprint of the whole operation, I compute the embeddings for my dataset before training. Here is pseudocode of my transfer learning scheme:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# load_data, load_backbone, load_mlp and train_mlp are my own helpers (not shown)
train_dataloader, test_dataloader = load_data()
backbone = load_backbone()

train_embeddings = []
train_labels = []

test_embeddings = []
test_labels = []

# Computing the embeddings once and for all
for X, y in train_dataloader:
    # detach() so the stored embeddings do not keep the autograd graph alive
    train_embeddings.append(backbone(X).detach())
    train_labels.append(y)

train_embeddings = torch.cat(train_embeddings, dim=0)
train_labels = torch.cat(train_labels, dim=0)

for X, y in test_dataloader:
    test_embeddings.append(backbone(X).detach())
    test_labels.append(y)

test_embeddings = torch.cat(test_embeddings, dim=0)
test_labels = torch.cat(test_labels, dim=0)

# Training a classification model on the embeddings
train_dataset = TensorDataset(train_embeddings, train_labels)
embed_train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)

test_dataset = TensorDataset(test_embeddings, test_labels)
# No need to shuffle the evaluation set
embed_test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)

mlp_head = load_mlp()
train_mlp(mlp_head, embed_train_dataloader, embed_test_dataloader)
```

As my backbone is pre-trained and I don't want to change its behavior, I could save computation by not tracking gradients, using one of the following two context managers:

  • torch.no_grad(): disables gradient tracking for every operation run inside the block
  • torch.inference_mode(): disables gradient tracking and additionally skips some autograd bookkeeping (view and version-counter tracking), which makes it faster.
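To make the difference between the two context managers concrete, here is a minimal sketch on a toy nn.Linear standing in for the backbone (an assumption; the real pre-trained model behaves the same way with respect to autograd). It only inspects the autograd flags of the outputs.

```python
import torch
import torch.nn as nn

# Toy stand-in for the pre-trained backbone (assumption, not the real model)
backbone = nn.Linear(8, 4)
x = torch.randn(2, 8)

# Default: autograd records the forward pass so gradients can reach the weights
out_default = backbone(x)

# torch.no_grad(): no graph is recorded for operations inside the block
with torch.no_grad():
    out_no_grad = backbone(x)

# torch.inference_mode(): no graph either, and the outputs are "inference tensors"
# that cannot later be saved for backward by autograd
with torch.inference_mode():
    out_inference = backbone(x)

print(out_default.requires_grad)     # True
print(out_no_grad.requires_grad)     # False
print(out_inference.requires_grad)   # False
print(out_inference.is_inference())  # True
```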

As I am not training the backbone, I was also tempted to put it in eval mode by calling backbone.eval().
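A detail worth checking when switching modes: eval() has to be called. Writing backbone.eval without parentheses only references the method and leaves the model in training mode. The sketch below uses a hypothetical toy nn.Sequential with a dropout layer as the backbone and only inspects the training flag.

```python
import torch.nn as nn

# Toy backbone with a dropout layer (assumption: stand-in for the real model)
backbone = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 4))

# Modules start in training mode
print(backbone.training)  # True

# Referencing the bound method without calling it changes nothing
backbone.eval
print(backbone.training)  # True

# Calling eval() sets training=False on the module and, recursively, on its submodules
backbone.eval()
print(backbone.training, backbone[1].training)  # False False

# train() (i.e. train(True)) switches everything back to training mode
backbone.train()
print(backbone.training)  # True
```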

I ran several experiments on the same dataset with the following configurations:

  • no_grad: using the torch.no_grad() context manager around the embedding extraction
  • inference: using the torch.inference_mode() context manager around the embedding extraction
  • eval: putting the backbone in eval mode using backbone.eval()
  • no_grad + eval: putting the backbone in eval mode + using the torch.no_grad() context manager
  • inference + eval: putting the backbone in eval mode + using the torch.inference_mode() context manager
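The five configurations above can be expressed as a single extraction function with two switches. This is a sketch, not my actual code: extract_embeddings, use_eval and grad_mode are hypothetical names, and the mode is set explicitly with train(not use_eval) so that each configuration is unambiguous instead of relying on whatever mode the backbone was loaded in.

```python
import contextlib

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def extract_embeddings(backbone, dataloader, use_eval=False, grad_mode="none"):
    """Compute embeddings once under one configuration (hypothetical helper).

    use_eval:  True -> backbone.eval(), False -> backbone.train()
    grad_mode: "none" (plain forward + detach), "no_grad" or "inference"
    """
    # Set the module mode explicitly (affects layers such as dropout)
    backbone.train(not use_eval)

    # Pick the autograd context for the whole extraction loop
    if grad_mode == "no_grad":
        ctx = torch.no_grad()
    elif grad_mode == "inference":
        ctx = torch.inference_mode()
    else:
        ctx = contextlib.nullcontext()

    embeddings, labels = [], []
    with ctx:
        for X, y in dataloader:
            embeddings.append(backbone(X).detach())
            labels.append(y)
    return torch.cat(embeddings, dim=0), torch.cat(labels, dim=0)


# Example: the "no_grad + eval" configuration on a toy backbone and dataset (stand-ins)
torch.manual_seed(0)
toy_backbone = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5), nn.Linear(8, 4))
toy_loader = DataLoader(TensorDataset(torch.randn(10, 8), torch.arange(10)), batch_size=4)
emb, lab = extract_embeddings(toy_backbone, toy_loader, use_eval=True, grad_mode="no_grad")
print(emb.shape, lab.shape)  # torch.Size([10, 4]) torch.Size([10])
```

The "inference" configuration, for instance, would be extract_embeddings(backbone, train_dataloader, use_eval=False, grad_mode="inference").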

For each of these configurations, I kept the same training scheme for the subsequent MLP head, so the only difference between the experiments is how the backbone is used. The following figure shows the test accuracy obtained for each configuration (2 repetitions each):

[Figure: results]

These results really confuse me. As expected, no_grad and inference_mode are equivalent. But I don't understand why the accuracy drops so much when eval mode is combined with no_grad or inference_mode.

What is the proper configuration for transfer learning? Only torch.no_grad(), as shown in most tutorials on the subject? Or should I also put the backbone in eval mode?

Thanks for your insights!

I'm coming back with more information on this issue, as it keeps bugging me. To me, running a model in inference mode (or no_grad mode) should not have any impact on its output. As I understand it, these modes only affect performance and, of course, gradient computation.
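This expectation can be checked on a small toy model: with the model in eval mode, the forward pass is deterministic, and wrapping it in torch.no_grad() or torch.inference_mode() should not change the values. The toy nn.Sequential below is an assumption standing in for the pre-trained model; this sketch only shows the expected behavior and does not reproduce the numbers reported below.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy backbone with dropout (assumption: stand-in for the pre-trained model)
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Dropout(0.2), nn.Linear(32, 8))
backbone.eval()  # dropout disabled, so the forward pass is deterministic
x = torch.randn(4, 16)

# Reference forward pass with autograd enabled
ref = backbone(x).detach()

with torch.no_grad():
    out_no_grad = backbone(x)
with torch.inference_mode():
    out_inference = backbone(x)

# In eval mode, the grad-mode context managers should leave the values unchanged
print(torch.allclose(ref, out_no_grad), torch.allclose(ref, out_inference))  # True True
```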

To check this, I ran an experiment: I took a pre-trained model and computed the embeddings of a dataset using 3 setups:

  • The model only in eval mode
  • The model in eval mode + inference mode
  • The model only in inference mode

I made 2 runs for each setup and, for each run, computed the mean of the absolute values of the embeddings to see whether the embeddings vary depending on the mode of the model. Here are the results:
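The measurement described above can be sketched as follows. The toy backbone and the random inputs tensor are assumptions standing in for the pre-trained model and the dataset, and mean_abs_embedding is a hypothetical helper; the figures reported below come from the real model, not from this sketch.

```python
import torch
import torch.nn as nn


def mean_abs_embedding(backbone, inputs, use_eval, use_inference):
    """Mean of |embedding| over all inputs for one setup (hypothetical helper)."""
    backbone.train(not use_eval)
    if use_inference:
        with torch.inference_mode():
            emb = backbone(inputs)
    else:
        emb = backbone(inputs).detach()
    return emb.abs().mean().item()


torch.manual_seed(0)
# Toy backbone and inputs (assumptions: stand-ins for the pre-trained model and dataset)
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Dropout(0.2), nn.Linear(32, 8))
inputs = torch.randn(64, 16)

# (use_eval, use_inference) for each setup
setups = {
    "eval": (True, False),
    "inference": (False, True),
    "inference + eval": (True, True),
}
for name, (use_eval, use_inference) in setups.items():
    runs = [mean_abs_embedding(backbone, inputs, use_eval, use_inference) for _ in range(2)]
    print(f"Means for {name} mode: {runs[0]:.8f} and {runs[1]:.8f}")
```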

Means for eval mode: 0.80259293 and 0.80259293
Means for inference mode: 0.79772127 and 0.797736
Means for inference + eval mode: 0.8016215 and 0.8016215

As expected, with inference mode alone (the model stays in training mode), the randomness of the dropout layers leads to slightly different results between runs.
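The reason the inference-only runs differ is that torch.inference_mode() only controls autograd; it does not switch modules to eval mode, so dropout stays active. A minimal sketch on a standalone nn.Dropout layer (an assumption standing in for the dropout layers inside the backbone):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)  # a freshly created module is in training mode
x = torch.ones(1, 1000)

with torch.inference_mode():
    # inference_mode does not touch the module's training flag
    print(dropout.training)   # True
    a, b = dropout(x), dropout(x)
    print(torch.equal(a, b))  # False: two different random masks

dropout.eval()
with torch.inference_mode():
    c, d = dropout(x), dropout(x)
    # In eval mode, dropout returns its input unchanged
    print(torch.equal(c, d), torch.equal(c, x))  # True True
```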

But I cannot find any explanation for the difference between eval mode only and eval mode + inference mode.

To me, inference mode should not affect the results, and the eval + inference setup should behave exactly like the eval-only one.

If any of you have hypotheses, I would be glad to test them!