A bug with Module.train()

Hello!
I'm working on a binary classification task.
My network uses Dropout and BatchNorm layers, so I call net.train(False) when evaluating it.
However, I noticed something strange when calculating the accuracy.
After training, I set net.train(False) and found that the accuracy on the training batches was only around 80%, even though the training loss was very low.
Then I switched to net.train(True), and the accuracy on the training batches rose to nearly 100%.
After that, I switched back to net.train(False), and surprisingly, the accuracy stayed above 99%.
Can anyone explain this? It's driving me crazy!

P.S. Here is my evaluation code:

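import torch

fnet.train(False)  # evaluation mode: BatchNorm uses its running stats, Dropout is disabled
total = 0
correct = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for data in train_batch:
        images = torch.as_tensor(data[0], dtype=torch.float32, device="cuda")
        labels = torch.as_tensor(data[1], dtype=torch.long, device="cuda").view(-1)
        outputs = fnet(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()  # .item() gives a Python int
print('Test accuracy of the network: %.5f %%' % (100 * correct / total))

print(fnet.training)  # False while in evaluation mode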

Here are the screenshots.

The BatchNorm layers got updated with the statistics of the test set when you switched back to net.train(True) and ran the evaluation loop: in training mode, every forward pass updates running_mean and running_var. After calling net.train(False) again, BatchNorm keeps using these updated running_mean and running_var.
This behavior is expected if your training and test sets differ in their statistics.
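A minimal sketch that reproduces this effect on a single standalone nn.BatchNorm1d layer (the layer size, the +5 shift, and the random data are illustrative assumptions, not taken from the original post): a forward pass in evaluation mode leaves running_mean untouched, a forward pass in training mode updates it, and switching back to evaluation mode keeps the updated value. Note that torch.no_grad() does not prevent the update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(4)  # running_mean starts at zeros, running_var at ones
# Illustrative batch whose per-feature mean is clearly non-zero (shifted by +5).
x = torch.randn(32, 4) + 5.0
before = bn.running_mean.clone()

bn.train(False)  # evaluation mode: running stats are used, not updated
with torch.no_grad():
    bn(x)
after_eval_pass = bn.running_mean.clone()
print(after_eval_pass)  # still all zeros

bn.train(True)  # training mode: each forward pass updates the running stats
with torch.no_grad():  # no_grad does not stop the buffer update
    bn(x)
after_train_pass = bn.running_mean.clone()
print(after_train_pass)  # moved toward the batch mean (roughly 0.5 per feature)

bn.train(False)  # switching back does NOT undo the update
print(bn.running_mean)
```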

Hello ptrblck!
Thanks for your answer! I understand it to some extent.
Then what should I do to evaluate my network?
thx!
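One common pattern is sketched below, under stated assumptions: `evaluate`, the toy `net`, and the random `data` are hypothetical stand-ins for fnet and train_batch. The helper runs the whole loop in evaluation mode inside torch.no_grad() and restores the previous mode afterwards, so a later forward pass never accidentally runs in training mode on evaluation data. This alone does not remove a mismatch between the stored running statistics and the data being evaluated.

```python
import torch
import torch.nn as nn

def evaluate(model: nn.Module, batches, device: str = "cpu") -> float:
    """Return accuracy (in %) of `model` over an iterable of (inputs, labels) pairs.

    Runs in evaluation mode, so BatchNorm reads (but never updates) its running
    statistics and Dropout is disabled; restores the previous mode afterwards.
    """
    was_training = model.training
    model.train(False)  # equivalent to model.eval()
    correct, total = 0, 0
    try:
        with torch.no_grad():  # no autograd graph is needed for evaluation
            for inputs, labels in batches:
                inputs = inputs.to(device)
                labels = labels.to(device).view(-1)
                predicted = model(inputs).argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
    finally:
        model.train(was_training)  # put the model back in its original mode
    return 100.0 * correct / total

# Hypothetical toy model and random data standing in for fnet and train_batch.
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(),
                    nn.Dropout(0.5), nn.Linear(16, 2))
data = [(torch.randn(10, 8), torch.randint(0, 2, (10,))) for _ in range(3)]
stats_before = net[1].running_mean.clone()
acc = evaluate(net, data)
print(f"accuracy: {acc:.2f}%")
```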

Btw, are you using the training data to evaluate the performance in your code snippet?
Doing so only gives you an estimate of the resubstitution error, which says nothing about the model's performance on new, unseen data.
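A minimal sketch of holding out a validation set with torch.utils.data.random_split, assuming the data can be wrapped in a Dataset; the random TensorDataset, the 80/20 ratio, the seed, and the batch size are illustrative assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in for the real dataset: 1000 samples, 8 features, binary labels.
features = torch.randn(1000, 8)
labels = torch.randint(0, 2, (1000,))
dataset = TensorDataset(features, labels)

# Hold out 20% of the samples; a fixed generator makes the split reproducible.
n_val = len(dataset) // 5
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val],
                                  generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False)
# Train only on train_loader; report accuracy on val_loader to estimate
# performance on data the model has not seen during training.
```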

Back to the question.
This is a known issue if your data statistics fluctuate a lot between batches. You could try to increase the momentum parameter in the BatchNorm layers, so that the running stats will be weighted higher.

Thx. I got it!
I used the training data to check for overfitting, which is how I found this problem.
Thank you so much!

Uhh, it should be decreasing the momentum, because BN's momentum argument is actually 1 - momentum compared with the conventional notion of momentum used in optimizers, due to unfortunate historical reasons.
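The PyTorch BatchNorm documentation gives the update rule as running = (1 - momentum) * running + momentum * batch_stat, so momentum is the weight of the new batch, not of the accumulated estimate. The sketch below (layer size, shift, and number of batches are arbitrary illustrative choices; `running_mean_after` is a hypothetical helper) checks that rule against nn.BatchNorm1d's running_mean and shows that a smaller momentum makes the running estimate move more slowly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def running_mean_after(momentum: float, batches) -> torch.Tensor:
    """Feed batches through a fresh BatchNorm1d in training mode; return its running_mean."""
    bn = nn.BatchNorm1d(3, momentum=momentum)
    bn.train(True)
    with torch.no_grad():
        for x in batches:
            bn(x)
    return bn.running_mean.clone()

batches = [torch.randn(16, 3) + 2.0 for _ in range(5)]

# Reproduce the update by hand: new = (1 - momentum) * old + momentum * batch_mean
momentum = 0.1
manual = torch.zeros(3)  # BatchNorm initializes running_mean to zeros
for x in batches:
    manual = (1 - momentum) * manual + momentum * x.mean(dim=0)

print(running_mean_after(momentum, batches))
print(manual)
# Smaller momentum: each new batch gets less weight, so the running stats change more slowly.
print(running_mean_after(0.01, batches))
```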

Thanks for the correction on this and sorry for the confusion!
I misunderstood the documentation, which says:

momentum – the value used for the running_mean and running_var computation.

It sounds like momentum is the value used to scale the running_* params accumulated so far.
With a default value of 0.1, though, my interpretation doesn't really make sense.
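The same documentation entry notes that momentum can also be set to None, in which case the running stats become a cumulative (simple) average over all batches seen. A sketch of using that to recompute the running stats after training, assuming you have an iterable of input batches; `recalibrate_bn`, the one-layer `net`, and the random batches are hypothetical. PyTorch's torch.optim.swa_utils.update_bn follows a similar idea. Because the passes run in training mode, any Dropout layers are also active during recalibration.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def recalibrate_bn(model: nn.Module, batches) -> None:
    """Recompute BatchNorm running stats as a simple average over `batches`."""
    bn_layers = [m for m in model.modules() if isinstance(m, BN_TYPES)]
    saved_momenta = [bn.momentum for bn in bn_layers]
    was_training = model.training
    for bn in bn_layers:
        bn.reset_running_stats()  # running_mean=0, running_var=1, num_batches_tracked=0
        bn.momentum = None        # cumulative moving average instead of exponential
    model.train(True)             # running stats are only updated in training mode
    with torch.no_grad():
        for x in batches:
            model(x)
    for bn, m in zip(bn_layers, saved_momenta):
        bn.momentum = m           # restore the original momentum values
    model.train(was_training)

torch.manual_seed(0)
net = nn.Sequential(nn.BatchNorm1d(3))
batches = [torch.randn(16, 3) + 2.0 for _ in range(4)]
recalibrate_bn(net, batches)
print(net[0].running_mean)
```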

Yeah, it is quite confusing. I have a pending PR to further clarify that.
