I am currently conducting transfer learning experiments where I have a pre-trained backbone and want to train a classification model on top of it. To reduce the computational footprint of the whole operation, I compute the embeddings for my dataset prior to training. Here is pseudo-code of my transfer learning scheme:
import torch
from torch.utils.data import DataLoader, TensorDataset

train_dataloader, test_dataloader = load_data()
backbone = load_backbone()

train_embeddings = []
train_labels = []
test_embeddings = []
test_labels = []

# Computing the embeddings once and for all
for X, y in train_dataloader:
    train_embeddings.append(backbone(X).detach())
    train_labels.append(y)
train_embeddings = torch.cat(train_embeddings, dim=0)
train_labels = torch.cat(train_labels, dim=0)

for X, y in test_dataloader:
    test_embeddings.append(backbone(X).detach())
    test_labels.append(y)
test_embeddings = torch.cat(test_embeddings, dim=0)
test_labels = torch.cat(test_labels, dim=0)

# Training a classification model on the embeddings
train_dataset = TensorDataset(train_embeddings, train_labels)
embed_train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_dataset = TensorDataset(test_embeddings, test_labels)
embed_test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=True)

mlp_head = load_mlp()
train_mlp(mlp_head, embed_train_dataloader, embed_test_dataloader)
As my backbone is pre-trained and I don't want to change its behavior, I could gain some performance by not computing gradients with one of the two following context managers (a sketch of how I wrap the extraction follows this list):
- torch.no_grad(): disables gradient computation for everything executed inside the block
- torch.inference_mode(): disables gradient computation and additionally skips autograd bookkeeping (view and version-counter tracking), which makes inference slightly faster.
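For reference, this is roughly how I wrap the extraction loop (a minimal sketch reusing the load_data()/load_backbone() helpers from the pseudo-code above; torch.inference_mode() can be swapped in for torch.no_grad(), but tensors created under inference mode may need a .clone() before being fed to a later autograd-recorded forward pass):

import torch

train_dataloader, _ = load_data()   # same helpers as in the pseudo-code above
backbone = load_backbone()

train_embeddings, train_labels = [], []
with torch.no_grad():               # or: with torch.inference_mode():
    for X, y in train_dataloader:
        # no graph is built inside the block, so .detach() becomes redundant
        train_embeddings.append(backbone(X))
        train_labels.append(y)
train_embeddings = torch.cat(train_embeddings, dim=0)
train_labels = torch.cat(train_labels, dim=0)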
As I am not training the backbone, I was also tempted to put it in eval mode with backbone.eval(), which switches modules such as Dropout and BatchNorm to their evaluation behavior.
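Concretely, eval() only flips the module's training flag recursively (a minimal sketch, again using the hypothetical load_backbone() helper):

backbone = load_backbone()   # hypothetical helper from the pseudo-code above
print(backbone.training)     # True by default: Dropout is active, BatchNorm updates its running statistics
backbone.eval()              # recursively sets training=False on every sub-module
print(backbone.training)     # False: Dropout is a no-op, BatchNorm uses its stored running statistics
# backbone.train() would switch it back if I ever fine-tuned the backbone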
I ran several experiments on the same dataset with the following configurations (a sketch of how I toggle between them follows this list):
- no_grad: using the torch.no_grad() context manager around the embedding extraction
- inference: using the torch.inference_mode() context manager around the embedding extraction
- eval: putting the backbone in eval mode with backbone.eval()
- no_grad + eval: putting the backbone in eval mode + using the torch.no_grad() context manager
- inference + eval: putting the backbone in eval mode + using the torch.inference_mode() context manager.
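For completeness, this is roughly how I select a configuration (a sketch, not my exact harness; the names are just the labels above, and contextlib.nullcontext() stands in for "no context manager"):

import contextlib
import torch

# Each label maps to (context-manager factory, whether to call backbone.eval())
CONFIGS = {
    "no_grad":          (torch.no_grad,          False),
    "inference":        (torch.inference_mode,   False),
    "eval":             (contextlib.nullcontext, True),
    "no_grad + eval":   (torch.no_grad,          True),
    "inference + eval": (torch.inference_mode,   True),
}

def extract_embeddings(backbone, dataloader, config_name):
    ctx_factory, use_eval = CONFIGS[config_name]
    if use_eval:
        backbone.eval()
    embeddings, labels = [], []
    with ctx_factory():
        for X, y in dataloader:
            embeddings.append(backbone(X).detach())
            labels.append(y)
    return torch.cat(embeddings, dim=0), torch.cat(labels, dim=0)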
For each of these configurations, I kept the same training scheme for the subsequent MLP head; the only thing that varies between experiments is therefore how the backbone is used. The following figure shows the test accuracy obtained for each scheme (with 2 repetitions):
And those results are really confusing to me. As expected, no_grad and inference_mode are equivalent. But I don't understand why the accuracy drops so much when eval mode is combined with no_grad or inference_mode.
What is the proper configuration for transfer learning? Only torch.no_grad(), as shown in the majority of tutorials on the subject? Or should I also set the backbone in eval mode?
Thanks for your insight!