The high validation loss is caused by inaccurate estimates of the running stats.
Since you are feeding a constant tensor (`batchone`: mean=1, std=0) and a random tensor (`batchtwo`: mean≈0, std≈1), the running estimates end up averaging two very different distributions and will be wrong for both inputs.
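To see why the estimates match neither input, here is a minimal pure-Python sketch of the exponential-moving-average update that BatchNorm applies to its running stats, `running = (1 - momentum) * running + momentum * batch_stat`. It assumes PyTorch's default `momentum=0.1`, its initial values (mean 0, var 1), and a hypothetical batch size of 64, and alternates the two kinds of batches:

```python
import random
import statistics

MOMENTUM = 0.1  # assumption: PyTorch's default BatchNorm momentum


def update(running, batch_stat, momentum=MOMENTUM):
    # Exponential moving average used for the running estimates
    return (1 - momentum) * running + momentum * batch_stat


random.seed(0)
running_mean, running_var = 0.0, 1.0  # PyTorch's initial running stats
for step in range(200):
    if step % 2 == 0:
        batch = [1.0] * 64  # like batchone: constant, mean=1, var=0
    else:
        batch = [random.gauss(0.0, 1.0) for _ in range(64)]  # like batchtwo
    running_mean = update(running_mean, statistics.fmean(batch))
    # PyTorch stores the unbiased batch variance in running_var
    running_var = update(running_var, statistics.variance(batch))

print(f"running_mean={running_mean:.2f}, running_var={running_var:.2f}")
```

Both estimates settle around 0.5, which is neither the stats of `batchone` (1, 0) nor those of `batchtwo` (≈0, ≈1).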
During training, each batch is normalized with its own (current) statistics, so the model can still converge.
However, during evaluation the batchnorm layer normalizes both inputs with the skewed running estimates, which yields the high loss values.
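A small reproduction sketch, assuming a single-feature `nn.BatchNorm1d` and hypothetical batches of 64 samples standing in for `batchone` and `batchtwo`: after alternating them in training mode, eval mode normalizes with the skewed running estimates, so the outputs are shifted and rescaled instead of being centered:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(1)  # defaults: track_running_stats=True, momentum=0.1

batch_one = torch.ones(64, 1)   # constant: mean=1, std=0
batch_two = torch.randn(64, 1)  # random: mean~=0, std~=1

# Training mode: each forward pass updates the running estimates,
# alternating between two very different distributions.
bn.train()
for _ in range(100):
    bn(batch_one)
    bn(batch_two)

print(bn.running_mean, bn.running_var)  # both end up around 0.5

# Eval mode: the layer now uses the running estimates instead of batch stats.
bn.eval()
with torch.no_grad():
    out_one = bn(batch_one)
    out_two = bn(batch_two)

# In training mode out_one would be all zeros and out_two would have
# mean ~0 and std ~1; here both are clearly off.
print(out_one[0].item(), out_two.mean().item(), out_two.std().item())
```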
Usually we assume that all inputs come from the same domain and thus have approximately the same statistics.
If you set `track_running_stats=False` in your BatchNorm layer, the batch statistics will also be used during evaluation, which will reduce the eval loss significantly.
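A minimal sketch of the suggested fix, using the same hypothetical setup as above: with `track_running_stats=False` the layer stores no running estimates, so eval mode normalizes each batch with its own statistics:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(1, track_running_stats=False)

batch_one = torch.ones(64, 1)   # constant: mean=1, std=0
batch_two = torch.randn(64, 1)  # random: mean~=0, std~=1

bn.train()
for _ in range(100):
    bn(batch_one)
    bn(batch_two)

bn.eval()
with torch.no_grad():
    out_one = bn(batch_one)  # normalized with batch_one's own stats
    out_two = bn(batch_two)  # normalized with batch_two's own stats

print(bn.running_mean)  # None: no running estimates are kept
print(out_one[0].item(), out_two.mean().item(), out_two.std().item())
```

The trade-off is that eval outputs now depend on which samples share a batch, so very small eval batches give noisy statistics.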