Model (resnet34) performs better without model.eval()

Hello,
I am using a ResNet34, fully retrained, to recognize sines in a voice message. With a relatively small training set I was able to achieve promising results (almost 100% accuracy). I split the signal into windows and then take the mel spectrogram images for training and inference - a pretty standard approach.
All this worked until I found out I was not calling model.eval(). I then added it to my code, and since then the model is not working at all, or in any case the performance is really bad (lots of false positives).
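For context, the pre-processing is roughly the sketch below (the file name, window length and mel parameters are just placeholders, not my exact values):

import torchaudio

# Split the waveform into fixed-length windows and turn each window into
# a mel spectrogram "image" that is fed to the ResNet34.
waveform, sample_rate = torchaudio.load("message.wav")  # placeholder file

window_len = sample_rate  # e.g. one-second windows (illustrative)
mel = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=64)

windows = waveform.unfold(dimension=-1, size=window_len, step=window_len)
# windows: (channels, num_windows, window_len)
spectrograms = [mel(w) for w in windows.unbind(dim=1)]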

Now I have read that for inference we must set model.eval(), but I don't quite understand why not setting it actually works, and works so well. Is there some technique I can apply in these circumstances? I see other people having similar problems, but I don't see anything I can do with a predefined model.

Your help is really appreciated.

Based on your description it seems the running stats of the batchnorm layers might not be properly capturing the mean and stddev of your validation dataset, which then causes the bad results. We have a few topics where similar issues were discussed, and e.g. changing the momentum of the batchnorm layers was suggested to try to smooth the running stats updates.
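Something along these lines could be used to change the momentum (0.01 is just an example value to experiment with):

import torch.nn as nn
from torchvision import models

model = models.resnet34()

# Lower the momentum of all batchnorm layers so the running estimates are
# updated more smoothly (the PyTorch default is 0.1).
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.momentum = 0.01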

Thank you,
I have read a few more articles on the subject… It is possible that my training set is not big enough.
However, while reading I have seen that other people in similar situations have not let the batch norm layers use their fixed running mean and var during inference.
Batch normalization in 3 levels of understanding

I tried this code:

import torch.nn as nn

def _freeze_norm_stats(net):
    # Put every BatchNorm2d layer back into training mode so it normalizes
    # with the statistics of the current batch instead of the running
    # estimates accumulated during training.
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d):
            # m.track_running_stats = False
            m.train()

and just before inference I set eval() and then simply switch the batchnorm layers back:

new_model_pred.eval()
_freeze_norm_stats(new_model_pred)

It seems to work well on the data in my use case.

Is there a drawback I do not see?

Obviously, over time I can build a wide enough training set, which should fix this problem.

The drawback of using the batch stats only (during training and validation) would be the dependency on the batch size. Since you are now normalizing the activations using the batch statistics, you might need to keep the batch size the same, as changing it might yield different results.
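As a toy illustration with random data: in train mode the same sample gets a different output depending on the batch it is processed in, so the batch size (and composition) matters:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3).train()  # train mode -> normalization uses batch stats

x = torch.randn(1, 3, 8, 8)        # one fixed sample
rest_a = torch.randn(3, 3, 8, 8)   # two different "rests of the batch"
rest_b = torch.randn(7, 3, 8, 8)

out_a = bn(torch.cat([x, rest_a]))[0]
out_b = bn(torch.cat([x, rest_b]))[0]
print(torch.allclose(out_a, out_b))  # False: same sample, different output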

Thank you, learning a lot.

I am now running a test where I am changing the momentum - similar code, but now fixing the momentum to a higher value. I think fixing the batch size is currently not too bad as step 1; I am building a PoC at this stage. Let's see what the momentum test tells me.

So the test still shows worse performance, but when I run on the same data as training it is a lot better than before. Surely it is about the variance of the input data between training and inference.

Is there a way to visualize the variance and std for pictures in a ResNet34 network? It would help me understand how the input data differs from one dataset to another.

I generally apply the same pre-processing and, using torchaudio, I apply the same transformations, so I would expect the data to fall in the same range.

Thanks,

You could use forward hooks and store some stats of the intermediate activations you are interested in. This post explains how to use forward hooks on modules.
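A minimal sketch (the random input is a placeholder; feed a batch of your spectrogram images instead):

import torch
import torch.nn as nn
from torchvision import models

model = models.resnet34().eval()
stats = {}

def make_hook(name):
    def hook(module, inp, out):
        # store mean/std of the activation entering this batchnorm layer
        stats[name] = (inp[0].mean().item(), inp[0].std().item())
    return hook

for name, m in model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        m.register_forward_hook(make_hook(name))

x = torch.randn(8, 3, 224, 224)  # replace with a batch of your data
with torch.no_grad():
    model(x)

print(stats)  # compare these numbers between training and inference data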