I am using a multi-instance learning approach for histopathology modeling. To fit all the patches from a single whole slide image (about 100k x 100k pixels), I use a batch size of 1, where each batch contains the 64-128 patches extracted from one whole slide image, so batch norm still sees enough samples per batch to compute reasonable statistics. Because I am using a pretrained ResNet, and the mean/std of medical images will be very different from those of ImageNet, I have increased the batch norm momentum to 0.9 and re-initialized the batch norm running statistics (mean to zero, std to one).

Training is fine, but validation is highly unstable. Based on my experiments, evaluating in both train and eval mode, I suspect the cause is `track_running_stats=True`. In eval mode the layers use the running mean/std accumulated during training, and since each new batch is a new whole slide image (64-128 patches), those running statistics shift quickly from slide to slide, which leads to poor validation performance. I can't switch to group norm or instance norm, because the community has not yet started sharing pretrained models with these normalization layers. So:
- Is it wrong to set `track_running_stats` to False at the evaluation stage in cases like this? Are there any papers or examples that set it to False?
- Is running validation in train mode technically wrong? I would only be updating the running mean and std with a momentum of 0.9.
- I have around 300 whole slide images for training, and keeping the momentum around 0.1 degrades performance. One option is to experiment further with this setup: re-initialize the batch norm running statistics and keep momentum=0.1. This might work and give stable performance even in eval mode, because with a low momentum the model is trained with a slowly changing mean/std (closer to a global mean/std than to a single slide's mean/std).
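For reference, here is a minimal sketch of the setup described above, assuming PyTorch. `reset_batchnorm` and `running_mean_after` are hypothetical helper names, and the small `nn.Sequential` model is only a stand-in for the pretrained ResNet. The first helper resets each BatchNorm layer's running mean to 0 and running variance to 1 and sets its momentum. The second shows how PyTorch's update rule, running = (1 - momentum) * running + momentum * batch_stat, behaves when consecutive batches come from different slides.

```python
from typing import Sequence

import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def reset_batchnorm(model: nn.Module, momentum: float) -> int:
    """Reset the running stats of every BatchNorm layer and set its momentum.

    reset_running_stats() sets running_mean to 0, running_var to 1 and
    num_batches_tracked to 0. The learned affine weight/bias (the pretrained
    gamma/beta) are kept; module.reset_parameters() would reset those too.
    Returns how many BatchNorm layers were changed.
    """
    count = 0
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.reset_running_stats()
            module.momentum = momentum
            count += 1
    return count


def running_mean_after(momentum: float, slide_means: Sequence[float]) -> float:
    """Feed one batch per "slide" (64 identical patches with value m) through a
    1-channel BatchNorm in train mode and return the final running mean."""
    bn = nn.BatchNorm1d(1, momentum=momentum)
    bn.train()
    with torch.no_grad():
        for m in slide_means:
            bn(torch.full((64, 1), float(m)))
    return bn.running_mean.item()


# Hypothetical stand-in for a pretrained ResNet; any nn.Module works the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8),
)
print(reset_batchnorm(model, momentum=0.9))  # 2 BatchNorm layers reset

# Two consecutive slides whose feature means are +5 and -5:
print(running_mean_after(0.9, [5.0, -5.0]))  # about -4.05: dominated by the last slide
print(running_mean_after(0.1, [5.0, -5.0]))  # about -0.05: a slow running average
```

With momentum=0.9 the running mean after two slides is almost entirely the second slide's mean, which is consistent with the suspicion that eval-mode statistics end up reflecting whichever slide was seen last; with momentum=0.1 the estimate moves much more slowly.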
Any suggestions are highly appreciated. On a lucky run I get pretty good training and validation accuracy (with `track_running_stats=False`), but I am hunting for a more stable, sustainable solution.
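One way to get the "track running stats False" behaviour at validation time without altering the stored statistics is sketched below, assuming PyTorch. `per_slide_batch_stats` is a hypothetical helper: it keeps the model in eval mode, switches only the BatchNorm layers to train mode so each forward pass is normalized with that batch's (that slide's) mean and variance, and copies the running buffers back on exit so validation data never leaks into them. In current PyTorch releases, just setting `track_running_stats = False` on an already-built layer is not enough in eval mode, because the layer still holds its running buffers and keeps using them.

```python
import contextlib

import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


@contextlib.contextmanager
def per_slide_batch_stats(model: nn.Module):
    """Normalize with the current batch's statistics while in eval mode.

    Everything except BatchNorm stays in eval mode (e.g. dropout is off).
    BatchNorm layers are put in train mode so they use batch statistics;
    their buffers (running_mean, running_var, num_batches_tracked) are
    snapshotted first and copied back on exit, so the stored statistics
    are unchanged. The model is left in eval mode afterwards.
    """
    model.eval()
    snapshots = []
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            saved = {name: buf.clone() for name, buf in module.named_buffers(recurse=False)}
            snapshots.append((module, saved))
            module.train()
    try:
        yield model
    finally:
        with torch.no_grad():
            for module, saved in snapshots:
                for name, value in saved.items():
                    getattr(module, name).copy_(value)
                module.eval()


# Hypothetical stand-in for a MIL model: patch features -> BN -> classifier.
model = nn.Sequential(nn.Linear(16, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
before = model[1].running_mean.clone()

slide_patches = torch.randn(96, 16)  # all patches of one validation slide, in ONE batch
with per_slide_batch_stats(model), torch.no_grad():
    logits = model(slide_patches)

print(logits.shape)  # torch.Size([96, 2])
print(torch.equal(model[1].running_mean, before))  # True: stored stats untouched
```

The statistics are only per-slide if all of a slide's patches go through in a single forward pass; splitting a slide into several chunks would normalize each chunk with its own statistics.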