Possible Issue with batch norm train/eval modes

I’m working on an unsupervised domain adaptation task with two datasets: synthetic, largely grayscale images (let’s call this dataset A) and real images (dataset B). The goal is to improve performance on dataset B while only using the labels from dataset A.

In the first step of my training process, I pre-train a ResNet model (which uses batch norm) on dataset A. After pretraining, I evaluate on dataset B and get an accuracy of around 50%.

Now comes the problem.

In my adaptation training code (def train()), I forward the images from both dataset A and dataset B in a single iteration. For example:

def train():
    for dataset_A_images, dataset_A_labels, dataset_B_images in dataloader:
        output_A = model(dataset_A_images)
        output_B = model(dataset_B_images)
        # in this fashion

I achieve a training accuracy of around 96%, which is good, but when I switch the model to evaluation mode in my evaluation function (def evaluate()), the accuracy on the same dataset A drops from 96% to 7%, with an absurdly large CE loss.

What’s weird is that I’ve tried evaluating in train mode, and the performance still drops to less than 10%.

I’m thinking this is an issue with (or misuse of) batch norm, but I can’t seem to figure out why the performance drops so drastically.

To summarize, the accuracy of dataset A (which is not split into train/val) takes the following rollercoaster journey:
After pretrain: 95%
During main training code: 96%
Evaluation right after training 1 epoch: 7%
During next training epoch: 96% again.

I’ve been pulling my hair out over this issue. Thanks in advance to anyone who takes the time to read this long post. :)

PS: I can’t share certain aspects of the code due to security reasons, but I’m happy to provide a more detailed description if the problem is not clear enough.

Let me try to summarize the issue to check if I understood the problem completely.
You are pretraining on the synthetic datasetA and achieve a very good performance on A (95%) and around 50% accuracy on datasetB when evaluating both datasets.

Then you are using your train method to finetune your model on this task. While your model still achieves a good performance on datasetA if the model is in train mode, it drops significantly to 7% if you switch your model to eval. This procedure repeats for each epoch. Is this a correct understanding?

Are you shuffling both datasets? How large are your batch sizes in train for both datasets?
Have you tried setting the model to eval() before passing the datasetB batch through the model in train()?
Also, have you checked the stats of both datasets, i.e. are they preprocessed in the same manner? Are the mean and std values completely different for both datasets?
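
To clarify what I mean by the eval() question: in your loop you could switch only the datasetB pass to eval mode, roughly like this (a sketch based on your snippet above; note that eval() also disables dropout for that pass):

def train():
    for dataset_A_images, dataset_A_labels, dataset_B_images in dataloader:
        model.train()                      # BN normalizes with batch stats and updates the running estimates (datasetA only)
        output_A = model(dataset_A_images)

        model.eval()                       # BN uses the stored running estimates and does not update them with datasetB
        output_B = model(dataset_B_images)
        model.train()                      # switch back before the next iteration
        # ... loss, backward, optimizer step as before ...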

Hey, thanks for the reply.

You seem to have understood the issue correctly. To answer some of your questions:

I do shuffle both train sets, and I use a batch size of 512; my custom joint dataset returns 512 samples from each dataset per batch. Both datasets also use the same transform (augmentation, normalization stats, etc.). The normalization stats were originally different, but I changed my code to use the same normalization after the first occurrence of this problem. Unfortunately, that didn’t fix anything.
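
To give a rough idea of the joint dataset (a simplified sketch with made-up names, not the actual code):

from torch.utils.data import Dataset

class JointDataset(Dataset):
    # returns a labeled sample from dataset A together with an unlabeled sample
    # from dataset B, applying the same transform to both
    def __init__(self, dataset_A, dataset_B, transform):
        self.dataset_A = dataset_A
        self.dataset_B = dataset_B
        self.transform = transform

    def __len__(self):
        return len(self.dataset_A)

    def __getitem__(self, idx):
        img_A, label_A = self.dataset_A[idx]
        img_B, _ = self.dataset_B[idx % len(self.dataset_B)]   # wrap around dataset B
        return self.transform(img_A), label_A, self.transform(img_B)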

Lastly, I have indeed tried training in eval() mode, just to see how that affects things. When I set eval() during the forward/backward pass of both datasets, the CE loss shoots up to around 5000 and the train accuracy drops to 10%.

I’m not really sure how to interpret this issue. It may be a code error on my part, but I want to mention that I use the same trainer code for smaller datasets such as MNIST, SVHN, and CIFAR, and those results have been successful. The new datasets I’m working with contain larger images, so the only parts of my code that have changed are the dataset and the model (a ResNet-50 with the fc layer replaced to fit the number of classes).
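
(For completeness, the model setup is just the standard torchvision ResNet-50 with the classifier replaced, along the lines of the sketch below; num_classes is a placeholder for the actual number of classes.)

import torch.nn as nn
from torchvision import models

num_classes = 12                                          # placeholder
model = models.resnet50()                                 # randomly initialized, BatchNorm throughout
model.fc = nn.Linear(model.fc.in_features, num_classes)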

From my understanding, it sounds like datasetB is changing the running estimates in the BatchNorm layers such that datasetA performs badly.

Could you try to calculate the mean and std of both datasets using this approach? It would be interesting to see if they are so different.
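
The gist of that approach is to accumulate per-channel sums over a DataLoader, roughly like this (assumes the images are tensors scaled to [0, 1] without normalization):

import torch

def compute_mean_std(loader):
    # accumulate per-channel statistics over all pixels in the dataset
    channel_sum = torch.zeros(3)
    channel_sq_sum = torch.zeros(3)
    num_pixels = 0
    for images, *_ in loader:                          # images: [N, C, H, W]
        num_pixels += images.size(0) * images.size(2) * images.size(3)
        channel_sum += images.sum(dim=[0, 2, 3])
        channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
    mean = channel_sum / num_pixels
    std = (channel_sq_sum / num_pixels - mean ** 2).sqrt()
    return mean, std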

Hi,

I had these stats calculated earlier.
For dataset A:
mean: [0.8279, 0.8259, 0.8216], std: [0.2331, 0.2331, 0.2458]

For dataset B:
mean: [0.4148, 0.3956, 0.3756], std: [0.2735, 0.2675, 0.2709]

So there does seem to be a large shift between the mean of the two datasets.

Naturally, the question now becomes: how can I modify the code to get good performance on dataset B, or even better, on both datasets? I would have expected leaving batch norm in train() mode during evaluation to yield decent results (since it uses the per-batch mean and std rather than the running estimates), but that doesn’t seem to change anything either.
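
(To be concrete, by “leaving batch norm in train() mode during evaluation” I mean something like the sketch below; note that the running estimates still get updated as a side effect, since the BN layers are in train mode.)

import torch
import torch.nn as nn

model.eval()                                   # keep dropout etc. in eval mode
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.train()                         # normalize with per-batch statistics

with torch.no_grad():
    for images, labels in eval_loader:         # eval_loader is a placeholder name
        outputs = model(images)
        # ... accuracy / CE loss as usual ...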

In the meantime, I’ll try the whole procedure again with the above stats and leave a reply once I’ve done that.

One approach would be to pass a few validation batches through the model in train mode to adjust the running estimates.
I’m not completely sure if that would count as some kind of data leakage, but you might give it a try and see if it helps.
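
Something along these lines (a sketch; datasetB_loader is a placeholder for a loader that yields image tensors from datasetB). Since there is no backward pass, only the BN buffers (running_mean/running_var) change:

import torch

model.train()                                  # BN updates its running estimates
with torch.no_grad():                          # no gradients, so the weights stay untouched
    for i, images in enumerate(datasetB_loader):
        model(images)
        if i >= 10:                            # a handful of batches is usually enough
            break
model.eval()                                   # evaluate with the adjusted estimates afterwards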

Although I haven’t been able to completely debug this issue yet, I have found the main culprit. So in case anyone runs into a similar issue, here is my final response:

  1. I ran multiple tests in which I sequentially stripped the training procedure of one forward pass at a time, until I was only forward passing a single batch from a single dataset. It turns out one major issue (?) was that my model had been pre-trained with SGD with momentum, while the main trainer was using Adam (with the same learning rate). For some reason, this was causing my train accuracy to plummet from 96% to 10% in the first epoch.

  2. I tried a variety of combinations with batch norm. Freezing the batch norm running stats by switching the BN layers to eval mode yields the most stable results (see the sketch after this list).

  3. While I haven’t been able to implement this yet, I think the best way to use batch norm with drastically different datasets is to keep a separate running mean/std for each dataset. It is a little burdensome, but I believe this is the most stable way of using batch norm in this case.
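
For point 2, the freezing looks roughly like this (a sketch; it has to be re-applied after every call to model.train(), and only the running estimates are frozen, the BN weight/bias still receive gradients):

import torch.nn as nn

def freeze_bn_running_stats(model):
    # put only the BatchNorm layers into eval mode so running_mean/running_var
    # stop updating; their affine parameters still train normally
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()

model.train()
freeze_bn_running_stats(model)   # call this after every model.train()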


I am currently running into a similar issue; however, I see the effect during .train() mode instead.

It is definitely a batchnorm issue, as I have forcibly removed the running means and that seems to solve it. Nevertheless, there doesn’t seem to be any clear solution for training on different distributions simultaneously (I was pre-training my discriminator using “fake only” and “real only” batches, and it failed to get anywhere due to the batchnorm running estimates).
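
For reference, one way to drop the running estimates is to construct the BatchNorm layers with track_running_stats=False, so batch statistics are used in both train and eval mode (a sketch, not my exact discriminator code):

import torch.nn as nn

# with track_running_stats=False the layer keeps no running_mean/running_var
# and always normalizes with the statistics of the current batch
disc_block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64, track_running_stats=False),
    nn.LeakyReLU(0.2, inplace=True),
)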