Removing nn.BatchNorm2d() improved my model

TL;DR:
Removing nn.BatchNorm2d() after the conv2d layer improved my model, and I don’t know why.

More info:
I’m using this implementation of UNet with a 256×257 input with 14 channels (the STFT of a signal recorded through 8 microphones), and I’m trying to classify each pixel of the 256×257 image as a certain angle (a number between 0 and 12).

I’m using nn.CrossEntropyLoss. Before I removed the batchnorm, the validation loss was decreasing in small steps with large volatility, but the results weren’t good enough. After googling a bit to understand why this happens, I found a thread on this forum describing a similar problem, where removing batchnorm improved that user’s model.
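
For context, here is a rough sketch of the shapes involved (the batch size and tensor names are just illustrative, not my actual code):

```python
import torch
import torch.nn as nn

# Illustrative shapes only: 14 input channels (STFT from 8 mics),
# 13 angle classes (0-12), a 256x257 spectrogram, and a made-up batch size.
batch_size, in_channels, num_classes = 4, 14, 13
height, width = 256, 257

spectrograms = torch.randn(batch_size, in_channels, height, width)    # model input
logits = torch.randn(batch_size, num_classes, height, width)          # what the UNet outputs
targets = torch.randint(0, num_classes, (batch_size, height, width))  # per-pixel angle labels

# nn.CrossEntropyLoss accepts (N, C, H, W) logits and (N, H, W) class indices
loss = nn.CrossEntropyLoss()(logits, targets)
print(loss.item())
```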

Removing it greatly improved my model’s results, but I have no idea why. Does anyone have an explanation?

Because batchnorm is not the best normalization technique for your task. Batch norm normalizes all the images in a batch with shared per-channel mean and std statistics, whereas for segmentation you often want separate statistics for each image. Use InstanceNorm instead.
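
For example, a sketch of what that swap could look like in a UNet-style double-conv block (illustrative only, not your actual implementation):

```python
import torch.nn as nn

class DoubleConv(nn.Module):
    """Conv -> Norm -> ReLU, twice. Hypothetical UNet building block;
    the norm layer is swappable depending on what works for the task."""

    def __init__(self, in_ch, out_ch, norm="instance"):
        super().__init__()
        # InstanceNorm2d normalizes each sample (and channel) on its own,
        # so statistics are not shared across the images in a batch.
        norm_layer = nn.InstanceNorm2d if norm == "instance" else nn.BatchNorm2d
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            norm_layer(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            norm_layer(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)
```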

2 Likes

In addition to batchnorm being a “bad fit” for your use case, you should check your current batch size, as batchnorm layers usually perform poorly with small batch sizes.
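
A rough way to see why: the mean/std that batchnorm computes from a small batch are themselves noisy estimates, so the normalization jumps around from step to step. A quick illustration (made-up numbers, not your data):

```python
import torch

torch.manual_seed(0)
population = torch.randn(100_000)  # stand-in for the activations of one channel

# How much the per-batch mean fluctuates for different batch sizes
for batch_size in (2, 8, 64):
    batch_means = torch.stack([
        population[torch.randint(0, len(population), (batch_size,))].mean()
        for _ in range(1_000)
    ])
    print(f"batch_size={batch_size:3d}  std of batch means: {batch_means.std().item():.3f}")
```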

3 Likes