Are the batch norm parameters updated every batch or every epoch?
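They are updated every batch. In training mode (the default), `running_mean` and `running_var` are updated on every forward pass, i.e. once per batch. The learnable affine parameters `weight` and `bias` are updated whenever `optimizer.step()` is called, which is usually also once per batch.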
```python
>>> import torch
>>> b = torch.nn.BatchNorm2d(3)
>>> b.running_mean
tensor([0., 0., 0.])
>>> x = b(torch.randn(8, 3, 224, 224))
>>> b.running_mean
tensor([ 0.0002, -0.0002, -0.0002])
>>> x2 = b(torch.randn(8, 3, 224, 224))
>>> b.running_mean
tensor([-5.2687e-05, -4.1823e-04, -1.2951e-04])
```
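For reference, here is a plain-Python sketch of the per-batch update behind those numbers. It follows the rule given in the PyTorch BatchNorm documentation: `running = (1 - momentum) * running + momentum * batch_statistic`, with the default `momentum=0.1` and the unbiased batch variance used for `running_var`. `update_running_stats` is a hypothetical helper (not a PyTorch API), and the batch values are made up for illustration.

```python
def update_running_stats(running_mean, running_var, channel_values, momentum=0.1):
    """Hypothetical helper mirroring BatchNorm's running-stat update for one batch.

    channel_values[c] holds every activation of channel c in the batch
    (all N * H * W values flattened into one list).
    """
    new_mean, new_var = [], []
    for c, values in enumerate(channel_values):
        n = len(values)
        batch_mean = sum(values) / n
        # running_var is updated with the unbiased (n - 1) batch variance
        batch_var = sum((v - batch_mean) ** 2 for v in values) / (n - 1)
        # exponential moving average: keep (1 - momentum) of the old value
        new_mean.append((1 - momentum) * running_mean[c] + momentum * batch_mean)
        new_var.append((1 - momentum) * running_var[c] + momentum * batch_var)
    return new_mean, new_var


# Same initial buffers as a fresh BatchNorm2d(3): zeros for the mean, ones for the variance
running_mean, running_var = [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]
batch = [[1.0, 3.0], [2.0, 2.0], [-1.0, 1.0]]  # made-up activations, 2 per channel

# Each training batch triggers one update, just like each call to b(...) above
for step in range(2):
    running_mean, running_var = update_running_stats(running_mean, running_var, batch)
    print(step, running_mean, running_var)
```

Because the update happens inside every forward call, an epoch of many batches produces many updates, not one.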