Weird BatchNorm behavior in eval mode

I have a simple example (PyTorch 1.1.0):

import torch
print(torch.__version__)

x = torch.rand(1, 3, 2, 2)

for running_stats in [False, True]:
    for use_eval in [False, True]:
        print('running_stats:', running_stats, 'eval:', use_eval)
        bn = torch.nn.BatchNorm2d(3)
        bn.track_running_stats = running_stats  # set after construction
        if use_eval:
            bn.eval()
        print('\t', bn.running_mean)
        y1 = bn(x)
        print('\t', bn.running_mean)
        y2 = bn(x)
        print('\t', bn.running_mean)
        print('\t', (y1 - y2).abs().sum())

I expected that setting bn.eval() would be enough for BN to return the same results and keep the same running mean and var. I also expected that different running mean and var would produce different results.

Output:

1.1.0
running_stats: False eval: False
	 tensor([0., 0., 0.])
	 tensor([0.0529, 0.0558, 0.0643])
	 tensor([0.1005, 0.1061, 0.1222])
	 tensor(0., grad_fn=<SumBackward0>)
running_stats: False eval: True
	 tensor([0., 0., 0.])
	 tensor([0.0529, 0.0558, 0.0643])
	 tensor([0.1005, 0.1061, 0.1222])
	 tensor(0., grad_fn=<SumBackward0>)
running_stats: True eval: False
	 tensor([0., 0., 0.])
	 tensor([0.0529, 0.0558, 0.0643])
	 tensor([0.1005, 0.1061, 0.1222])
	 tensor(0., grad_fn=<SumBackward0>)
running_stats: True eval: True
	 tensor([0., 0., 0.])
	 tensor([0., 0., 0.])
	 tensor([0., 0., 0.])
	 tensor(0., grad_fn=<SumBackward0>)

Issues:

  1. The running mean and var seem to have absolutely no effect on the output in this example.
  2. Setting bn.eval() is not enough to keep the running mean and var unchanged.
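For reference, issue 1 is explained by the fact that whenever BN normalizes with the batch's own statistics, the running buffers are only written to, never read. A minimal sketch (written against a current PyTorch; the seed and tolerance are my own choices) reproducing training-mode output by hand:

```python
import torch

torch.manual_seed(0)
x = torch.rand(1, 3, 2, 2)

bn = torch.nn.BatchNorm2d(3)  # training mode by default
y = bn(x)

# In training mode, batch_norm normalizes with the per-channel batch mean
# and the *biased* batch variance -- running_mean/running_var are only
# updated as a side effect, not used in the computation.
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
y_manual = (x - mean) / torch.sqrt(var + bn.eps)
y_manual = y_manual * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1)

print(torch.allclose(y, y_manual, atol=1e-5))  # True
```

Since the same x is fed twice, the batch statistics are identical both times, which is why y1 and y2 match even though the running buffers moved in between.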

Hi,

You cannot change bn.track_running_stats after the layer has been created. It should be passed as an argument to the BatchNorm2d constructor.
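To make this concrete (a minimal sketch): when track_running_stats=False is passed at construction, the running buffers are never registered at all, whereas flipping the attribute afterwards leaves the already-registered buffers in place, so the module does not behave as if tracking had been disabled from the start:

```python
import torch.nn as nn

# Passed at construction: no running buffers are registered.
bn_no_stats = nn.BatchNorm2d(3, track_running_stats=False)
print(bn_no_stats.running_mean)  # None

# Flipped after construction: the buffers still exist, so the
# forward logic does not see the layer as "untracked".
bn = nn.BatchNorm2d(3)
bn.track_running_stats = False
print(bn.running_mean)  # still a tensor of zeros, not None
```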

But changing the running mean still has no effect on the output. It seems the mean and var are somehow encapsulated and immutable inside the C++ part.

This shouldn’t be the case as shown here:

import torch
import torch.nn as nn

x = torch.rand(1, 3, 2, 2)
bn = nn.BatchNorm2d(3)
bn.eval()
print(bn.running_mean, bn.running_var)
out = bn(x)
print(out)

bn.running_mean = torch.tensor([100., 100., 100.])
print(bn.running_mean, bn.running_var)
out = bn(x)
print(out)
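To double-check that eval-mode BN really reads the buffers, one can reproduce the second output by hand from the assigned running stats (a sketch, assuming a current PyTorch; the tolerance is my own choice):

```python
import torch
import torch.nn as nn

x = torch.rand(1, 3, 2, 2)
bn = nn.BatchNorm2d(3)
bn.eval()

# Replace the running-mean buffer, as in the snippet above.
bn.running_mean = torch.tensor([100., 100., 100.])

# In eval mode, batch_norm normalizes with the stored buffers,
# so the assigned mean of 100 shows up directly in the output.
y = bn(x)
mean = bn.running_mean.view(1, -1, 1, 1)
var = bn.running_var.view(1, -1, 1, 1)
y_manual = (x - mean) / torch.sqrt(var + bn.eps)
y_manual = y_manual * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1)

print(torch.allclose(y, y_manual, atol=1e-5))
```

If the buffers were truly "immutable inside the C++ part", this check would fail; it passes because F.batch_norm receives whatever tensors the Python module currently holds.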