Trying to understand an issue with batch normalization

Hi everybody,

I am running into an issue with batch normalization in my model, which is a CNN.

To debug my network, I took a single batch of 10 samples. Each sample is a picture of size (325, 40), so each sample tensor has shape (1, 1, 325, 40) and the full batch is a tensor of shape (10, 1, 325, 40).
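Just to illustrate the shapes (random data here, not my real preprocessing):

```python
import torch

sample = torch.randn(1, 1, 325, 40)   # one picture: (N=1, C=1, H=325, W=40)
batch = torch.randn(10, 1, 325, 40)   # full debug batch: (N=10, C=1, H=325, W=40)
```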

In my training loop I switch to training mode with model.train(), and for validation I call model.eval().
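Simplified, my loop looks roughly like this (model, batch, target, optimizer, and criterion are placeholders for my actual setup):

```python
import torch

def run_step(model, batch, target, optimizer, criterion):
    # Training step on the single debug batch: batch norm normalizes with the
    # batch statistics and updates its running statistics here.
    model.train()
    optimizer.zero_grad()
    train_loss = criterion(model(batch), target)
    train_loss.backward()
    optimizer.step()

    # "Validation" on the very same batch: batch norm normalizes with its
    # running statistics instead.
    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(batch), target)

    return train_loss.item(), val_loss.item()
```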

I also know that during training the running mean and variance used by batch normalization are updated, and since I am using batches of 10 I set a momentum of 0.1 for my batch norm layers.
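If I understand PyTorch's convention correctly, momentum = 0.1 means the running statistics move 10% of the way toward each batch's statistics on every forward pass in train mode. A minimal sketch:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(num_features=1, momentum=0.1)
x = torch.randn(10, 1, 325, 40)

bn.train()
_ = bn(x)  # a forward pass in train mode updates the running buffers

# Internally (PyTorch convention):
#   running_mean = (1 - momentum) * running_mean + momentum * batch_mean
#   running_var  = (1 - momentum) * running_var  + momentum * unbiased_batch_var
print(bn.running_mean, bn.running_var)
```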

At first, both my training loss and my validation loss decrease and stay on the same scale (training_loss = O(validation_loss)).

But after a while my validation loss starts diverging.
To be clear, my validation set and my training set are exactly the same data here (that single batch).

I think this is due either to how the running statistics are computed or to the learned affine parameters gamma and beta (the scale and shift that batch norm learns).
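To narrow it down, something like this illustrative helper (dump_bn_stats is just a name I made up) should let me watch whether gamma, beta, and the running statistics drift during training:

```python
import torch.nn as nn

def dump_bn_stats(model):
    # Print the learned affine parameters and the running statistics of every
    # batch norm layer (assumes affine=True, the default).
    for name, m in model.named_modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            print(f"{name}: "
                  f"gamma={m.weight.data.mean().item():.4f} "
                  f"beta={m.bias.data.mean().item():.4f} "
                  f"running_mean={m.running_mean.mean().item():.4f} "
                  f"running_var={m.running_var.mean().item():.4f}")
```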

Would you have an idea of what the problem might be here?

PS:

1/ On my validation set, if I use model.train() instead of model.eval(), the training_loss and validation_loss stay really close to each other (see the sketch after this list).

2/ I also tried using bigger batches (batches of 50), but the same problem occurs.
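Here is roughly how I compare the two modes for PS 1/ (same placeholder names as above):

```python
import torch

def compare_bn_modes(model, batch, target, criterion):
    model.train()
    with torch.no_grad():
        # Normalizes with the current batch's statistics.
        # Note: even under no_grad, this forward pass still updates the running buffers.
        loss_train_mode = criterion(model(batch), target).item()

    model.eval()
    with torch.no_grad():
        # Normalizes with the accumulated running statistics.
        loss_eval_mode = criterion(model(batch), target).item()

    return loss_train_mode, loss_eval_mode
```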

Thanks!