Validation spike during partial fine-tuning


Hello everyone.
I am training a ResNet50 model (pre-trained with RadImageNet's PyTorch weights) using a partial fine-tuning approach: only the last residual block and the FC layer are trainable.
I am working with 525 training samples, a learning rate of 1e-4, a batch size of 8, and the Adam optimizer.
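For reference, this is roughly how I set up the partial fine-tuning (a minimal sketch: it assumes torchvision's `resnet50`, the checkpoint path and the number of output classes are placeholders for my local setup):

```python
import torch
import torch.nn as nn
from torchvision.models import resnet50

# Build the backbone and load the RadImageNet weights
# (checkpoint path is a placeholder; strict=False in case key names differ).
model = resnet50(weights=None)
state_dict = torch.load("RadImageNet-ResNet50.pt", map_location="cpu")
model.load_state_dict(state_dict, strict=False)

# Replace the classifier head (number of classes is a placeholder).
model.fc = nn.Linear(model.fc.in_features, 2)

# Freeze everything, then unfreeze the last residual block and the FC layer.
for param in model.parameters():
    param.requires_grad = False
for param in model.layer4.parameters():
    param.requires_grad = True
for param in model.fc.parameters():
    param.requires_grad = True

# The optimizer only receives the trainable parameters.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)
```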
As you can see in the plot, the training accuracy reaches approximately 100% within the first few epochs, while the training loss quickly drops to nearly zero. The validation curve, however, shows a spike in the initial epochs.
Have you ever observed similar behaviour with transfer learning? Could it be caused by the batch normalization layers? What can I try in order to avoid it? Is PyTorch's ReduceLROnPlateau an option worth considering? I am afraid that, with this scheduler, the LR may become too small for the model to converge.
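To make the two ideas concrete, this is roughly what I had in mind (a sketch only: `model` and `optimizer` come from the setup above, and the scheduler values, including the `min_lr` floor, are placeholders I am unsure about):

```python
import torch
import torch.nn as nn

# Idea 1: keep all batch-norm layers in eval mode so their running
# statistics are not updated by my small batches (batch size 8).
# Note: the affine BN weights inside layer4 still receive gradients.
def freeze_bn_stats(model):
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()

model.train()
freeze_bn_stats(model)  # would need to be called after every model.train()

# Idea 2: ReduceLROnPlateau with a floor on the LR, so it cannot shrink
# below a value where the model stops converging.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-6
)

# ...inside the training loop, after computing the validation loss:
# scheduler.step(val_loss)
```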
Thank you.