Result of NN dependent on the number of instances passed

Hey guys,

I have a question. I don’t want to post any code right now but I can if you need it to help me.

I made a conv nn. No fc layers etc. Just conv (I don’t know if that is relevant). I train the network on 10 instances (mapping of some 20x20 pixels images onto another 20x20 pixels) and I run a batch of 10 instances at once.

The training error of the network is great (very low). So I run all 10 input instances from the training set through the network (one batch of 10 instances), and the results are what I expected. Then I run only one instance through the network (1 batch with 1 instance) and the result is way different, completely wrong compared to when the same instance was passed through the network in a batch of 10.

Is this normal or is something wrong? I’m completely flabbergasted.

EDIT: I’m using batch norm if that matters.

BatchNorm won’t work well with a single sample, since the batch statistics are then computed from just that one sample. If the training is good and you would like to pass a single image through the net, you can set the model to evaluation mode with model.eval(). Otherwise, if you would like to train the model on single samples, you should remove the BatchNorm layers and use InstanceNorm, for example.
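A minimal sketch of the difference (the architecture here is made up, not the original poster's net): in train mode a BatchNorm layer normalizes with the statistics of the current batch, while after `model.eval()` it uses the running statistics collected during training.

```python
import torch
import torch.nn as nn

# Hypothetical conv-only net with BatchNorm, mapping 20x20 images to 20x20 images
net = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3, padding=1),
)

x = torch.randn(1, 1, 20, 20)  # a single 20x20 image, batch size 1

net.train()
out_train = net(x)  # normalizes with the stats of this one sample

net.eval()
with torch.no_grad():
    out_eval = net(x)  # normalizes with the running stats from training
```

With batch size 1 the two outputs generally differ, which matches the behavior described above.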

Thank you ptrblck, that did the trick :).

Although I’m wondering now how PyTorch implements batch norm. From what I have learned, two parameters are trainable in batch norm: the scale and bias applied to the normalized vector. But the batch’s mean and standard deviation are sometimes trainable too. If they indeed are trainable, then I don’t think it would make any difference whether we run a single instance or many instances through the network. So the question is: when using batch norm, does PyTorch also train the mean and standard deviation?
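You can check this directly in PyTorch: the scale (gamma) and bias (beta) are registered as trainable parameters, while the running mean and variance are registered as buffers, which the optimizer never touches.

```python
import torch.nn as nn

bn = nn.BatchNorm2d(4)

# gamma (weight) and beta (bias) are trainable parameters
print([name for name, _ in bn.named_parameters()])
# -> ['weight', 'bias']

# the running stats are buffers: updated during the forward pass, not by gradients
print([name for name, _ in bn.named_buffers()])
# -> ['running_mean', 'running_var', 'num_batches_tracked']
```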

The mean and std are updated in each forward pass. The momentum term defines how much of the new value will be added to (1 - momentum) of the old “running” stats.

From the docs:

This momentum argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is x̂_new = (1 − momentum) × x̂ + momentum × x_t, where x̂ is the estimated statistic and x_t is the new observed value.

Because this implementation is a moving average, using single samples might be problematic, since the running stats will change a lot.
Usually, BatchNorm yields good results using a batch size of ~64.
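The update rule from the docs can be verified with a tiny worked example (the layer and values here are chosen for illustration): `running_mean` starts at 0, so after one batch with mean 10 and momentum 0.1, it becomes (1 − 0.1) × 0 + 0.1 × 10 = 1.0.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(1, momentum=0.1)
print(bn.running_mean)  # starts at tensor([0.])

x = torch.full((4, 1), 10.0)  # one batch of 4 samples, batch mean = 10
bn(x)  # a forward pass in train mode updates the running stats

# (1 - 0.1) * 0 + 0.1 * 10 = 1.0
print(bn.running_mean)  # tensor([1.])
```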

Ok, so what happens when I switch into eval mode? Mean/variance don’t get updated then?

Exactly. During evaluation you often pass single images through your net and don’t want to update any parameters or buffers.
The running stats as well as beta and gamma are just applied then.
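This is easy to confirm: after `eval()`, a forward pass leaves the running stats untouched.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
bn.eval()

before = bn.running_mean.clone()
with torch.no_grad():
    bn(torch.randn(1, 3, 20, 20))  # forward pass in eval mode
after = bn.running_mean

# the running stats were only applied, not updated
print(torch.equal(before, after))  # True
```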

Same goes for Dropout. During training the units are dropped randomly depending on the specified drop probability. During evaluation all units are used.
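The same train/eval switch can be seen with a standalone Dropout layer: in train mode units are zeroed at random (and the survivors scaled by 1/(1 − p)), while in eval mode the layer is the identity.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(10)

drop.train()
y_train = drop(x)  # roughly half the units zeroed, survivors scaled to 2.0

drop.eval()
y_eval = drop(x)  # identity: all units kept as-is
print(torch.equal(y_eval, x))  # True
```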

One more thing regarding momentum. If I’m understanding correctly, the running mean and stddev are updated with each batch, and each new batch contributes 10% (if momentum is 0.1) to the current mean and stddev. If that is the case, then after all batches have gone through the network, the last batch contributed 10% to the final mean/stddev, whereas the contribution of the first batch has mostly been “washed out”. Is my understanding correct?

You are right.
That’s why it’s useful to use drop_last in the DataLoader if your last batch is really small, e.g. one single image.
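A sketch of what `drop_last` does (the dataset and batch size here are made up): with 10 samples and a batch size of 3, the leftover batch of 1 sample is simply skipped.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# hypothetical dataset: 10 grayscale 20x20 images
ds = TensorDataset(torch.randn(10, 1, 20, 20))

loader = DataLoader(ds, batch_size=3, drop_last=True)
sizes = [batch[0].shape[0] for batch in loader]
print(sizes)  # [3, 3, 3] -- the trailing single-sample batch is dropped
```

Without `drop_last=True` the sizes would be `[3, 3, 3, 1]`, and that final batch of 1 would distort the running stats.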

I see. But still, that doesn’t solve the problem of “washing out” those unfortunate batches that had the privilege to go first :). Wouldn’t simply averaging over all batches, each with equal contribution, behave in the most optimal way?

You can calculate the mean and std of the input data, sure.
That’s usually done for normalization.
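Computing such equally weighted statistics over the input data is straightforward (the dataset here is hypothetical); these values are what you would typically pass to something like `transforms.Normalize`:

```python
import torch

# hypothetical training set: 100 grayscale 20x20 images
data = torch.randn(100, 1, 20, 20)

# per-channel mean/std over all samples and pixels, every sample weighted equally
mean = data.mean(dim=(0, 2, 3))
std = data.std(dim=(0, 2, 3))
print(mean.shape, std.shape)  # one value per channel
```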

However, BatchNorm is used between layers and normalizes the activations.
So iterating over your whole dataset and updating each BatchNorm layer separately would be very costly.
You would have to update the running stats layer by layer, and after each weight update the activation stats would change, so you would have to do it all over again.

In practice the momentum approach works quite well.

Got it. Thank you ptrblck a bunch for the discussion.