Hi!
Did you manage to solve this?
I’m trying to do the same thing: training with a fixed running mean/variance for the batchnorm layers.
I set all batchnorm layers to eval mode during training using this function:
```python
import torch.nn as nn

def set_bn_eval(net):
    net.train()  # put the whole network in training mode
    for module in net.modules():
        # BatchNorm layers in eval mode normalize with running_mean/running_var
        # and stop updating those statistics
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()
```
However, I get very poor training results: the validation score drops to zero straight away. Running the same training without setting the batchnorm layers to eval mode works fine.
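A quick way to check that the statistics really stay frozen is to compare a BatchNorm layer's `running_mean` before and after a training step. This is a minimal sketch assuming a toy `nn.Sequential` model; `freeze_bn_stats` is a hypothetical helper name, not a PyTorch API. It also shows that calling `net.train()` again (for example at the start of each epoch) switches the BatchNorm layers back to training mode, so the freeze has to be reapplied after every such call.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def freeze_bn_stats(net: nn.Module) -> nn.Module:
    """Hypothetical helper: train mode everywhere, eval mode for BatchNorm layers."""
    net.train()
    for m in net.modules():
        if isinstance(m, BN_TYPES):
            m.eval()  # use and keep running_mean/running_var fixed
    return net

torch.manual_seed(0)
# Toy model, assumed for illustration only
net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
bn = net[1]

freeze_bn_stats(net)
before = bn.running_mean.clone()
out = net(torch.randn(16, 4) * 5 + 3)
out.sum().backward()

# Running stats untouched, but the affine weight/bias still get gradients
frozen_unchanged = torch.equal(before, bn.running_mean)
gamma_has_grad = bn.weight.grad is not None

# A plain net.train() puts BatchNorm back in training mode, so stats update again
net.train()
net(torch.randn(16, 4) * 5 + 3)
updated_after_train = not torch.equal(before, bn.running_mean)

print(frozen_unchanged, gamma_has_grad, updated_after_train)
```

If `frozen_unchanged` holds but training still diverges, the problem is likely elsewhere (for example, the running statistics themselves not matching the training data), rather than the freeze being silently undone.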