Model.eval() gives incorrect loss for model with batchnorm layers

It's pretty huge: with net.train() the loss is 7e-8, with net.eval() it's 99.3.

Here is the truncated output. (By the way, I fixed a bug in the original repro code where the input batch size did not match the target batch size: now batchone = torch.ones(2, 1, 64, 64)… instead of (4, 1, 64, 64).)

EDIT: If you do test the repro code, I suggest running it multiple times; depending on the random initialization, the network sometimes explodes or does not converge as well. Validating with net.eval() on a model with BatchNorm2d is consistently worse than validating with net.train(), though, so there definitely seems to be a bug. A sketch of the repro follows below.
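For reference, here is a minimal sketch of the kind of repro I am running. It is my own reconstruction, not the exact original script: the layer sizes, learning rate, second batch, and all names other than batchone are assumptions.

```python
# Minimal repro sketch (reconstruction, not the original script):
# a tiny conv + BatchNorm2d net trained on two fixed batches, then
# validated once in eval mode and once in train mode.
import torch
import torch.nn as nn

torch.manual_seed(0)  # results vary with initialization, see note above

net = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 1, 3, padding=1),
)
opt = torch.optim.SGD(net.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# Two fixed input batches with matching input/target batch sizes
# (the bug fix mentioned above: everything is a batch of 2).
batchone = torch.ones(2, 1, 64, 64)
batchtwo = torch.randn(2, 1, 64, 64)
targetone = torch.zeros(2, 1, 64, 64)
targettwo = torch.ones(2, 1, 64, 64)

net.train()
for step in range(2000):
    for x, y, tag in ((batchone, targetone, "t loss1"),
                      (batchtwo, targettwo, "t loss2")):
        opt.zero_grad()
        loss = loss_fn(net(x), y)
        loss.backward()
        opt.step()
        if step % 100 == 0:
            print(f"{tag}: {loss.item()}")

# Validation in eval mode: BatchNorm uses its running statistics.
net.eval()
with torch.no_grad():
    print("v loss1:", loss_fn(net(batchone), targetone).item())
    print("v loss2:", loss_fn(net(batchtwo), targettwo).item())

# Validation in train mode: BatchNorm normalizes with the
# current batch's statistics instead.
net.train()
with torch.no_grad():
    print("train v loss1:", loss_fn(net(batchone), targetone).item())
    print("train v loss2:", loss_fn(net(batchtwo), targettwo).item())
```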

t loss1: 0.12320221960544586
t loss2: 0.09333723783493042
t loss1: 0.24439916014671326
t loss2: 0.5218815207481384
t loss1: 0.6079472899436951
t loss2: 3.729689598083496
t loss1: 2.112783908843994
t loss2: 28.2861328125
t loss1: 10.315664291381836
t loss2: 220.0238037109375
t loss1: 62.95467758178711
t loss2: 1715.936279296875
t loss1: 380.442138671875
t loss2: 12547.478515625
t loss1: 537.3038330078125
t loss2: 53827.6484375
t loss1: 10580.1064453125
t loss2: 484.6171875
t loss1: 625.4744873046875
t loss2: 174.572509765625
t loss1: 340.2962646484375
t loss2: 110.48974609375
t loss1: 199.84153747558594
t loss2: 81.0360107421875
t loss1: 124.49140167236328
t loss2: 59.60415267944336
t loss1: 80.25579071044922
t loss2: 44.34978485107422
t loss1: 53.18634796142578
t loss2: 33.68086624145508
t loss1: 36.14666748046875
t loss2: 26.259246826171875
t loss1: 25.240781784057617
t loss2: 20.230873107910156
t loss1: 17.960424423217773
t loss2: 15.522289276123047
t loss1: 12.958020210266113
t loss2: 11.859451293945312
....network converges....
t loss1: 8.753518159210216e-07
t loss2: 3.449472387728747e-06
t loss1: 7.535315944551257e-07
t loss2: 3.1326808311860077e-06
t loss1: 6.516746680063079e-07
t loss2: 2.851738372555701e-06
t loss1: 5.662100193148945e-07
t loss2: 2.603399707368226e-06
t loss1: 4.90573881961609e-07
t loss2: 2.3800134840712417e-06
t loss1: 4.2635625163711666e-07
t loss2: 2.181120635214029e-06
t loss1: 3.708064468810335e-07
t loss2: 2.0018892428197432e-06
t loss1: 3.229307878882537e-07
t loss2: 1.8403325157123618e-06
t loss1: 2.822781368649885e-07
t loss2: 1.6948088159551844e-06
t loss1: 2.4606848114672175e-07
t loss2: 1.5627991842848132e-06
t loss1: 2.1536402527999599e-07
t loss2: 1.4432781654249993e-06
t loss1: 1.8830129988600675e-07
t loss2: 1.3343917544261785e-06
t loss1: 1.6481742193263926e-07
t loss2: 1.2349712505965726e-06
t loss1: 1.4438909090586094e-07
t loss2: 1.1446034022810636e-06
t loss1: 1.2657532977300434e-07
t loss2: 1.0619536396916374e-06
t loss1: 1.1093613494495003e-07
t loss2: 9.862114893621765e-07
t loss1: 9.728184124924155e-08
t loss2: 9.167142138721829e-07
t loss1: 8.523475969468564e-08
t loss2: 8.52758375913254e-07
v loss1: 99.31019592285156
v loss2: 15.398092269897461
train v loss1: 7.480510788582251e-08
train v loss2: 7.63321224894753e-07
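To see where the eval/train gap comes from, one thing worth checking is how far the BatchNorm running statistics (which net.eval() uses) have drifted from the actual per-batch statistics (which net.train() uses). A quick sketch, reusing the net and batchone names from the repro sketch above:

```python
# Compare the BatchNorm layer's running statistics (used in eval
# mode) with the statistics of an actual training batch (used in
# train mode). A large gap here would explain the loss discrepancy.
bn = net[1]  # the BatchNorm2d layer from the sketch above
with torch.no_grad():
    x = net[0](batchone)  # activations entering the BN layer
    batch_mean = x.mean(dim=(0, 2, 3))
    batch_var = x.var(dim=(0, 2, 3), unbiased=False)
print("running_mean:", bn.running_mean)
print("batch mean:  ", batch_mean)
print("running_var: ", bn.running_var)
print("batch var:   ", batch_var)
```

A large difference between the running and per-batch statistics would account for the v loss vs. train v loss gap reported above.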