I am running into a commonly reported issue: a model that uses BatchNorm reaches lower accuracy when it is trained with DistributedDataParallel (DDP).
According to this workaround, I am supposed to monkey-patch BatchNorm as follows:
def monkey_patch_bn():
    """Replace ``torch.nn.functional.batch_norm`` with a cuDNN-free variant.

    The replacement has the same signature as the function it replaces. When
    ``training`` is true, it first rejects inputs with only one value per
    channel, raising ``ValueError``. It then calls ``torch.batch_norm`` with
    the cuDNN flag set to ``False``, so the non-cuDNN batch-norm kernel is
    used.

    The patch is global: it rebinds the attribute on the
    ``torch.nn.functional`` module, so every later lookup of
    ``torch.nn.functional.batch_norm`` gets the replacement.
    NOTE(review): ``nn.BatchNorm*`` layers presumably pick this up too,
    because they look up ``F.batch_norm`` at call time. Confirm this for the
    installed PyTorch version.

    Returns:
        The function that was installed before patching. Restore it with
        ``torch.nn.functional.batch_norm = original`` to undo the patch.
    """
    original = torch.nn.functional.batch_norm

    def batch_norm(input, running_mean, running_var, weight=None, bias=None,
                   training=False, momentum=0.1, eps=1e-5):
        if training:
            size = input.size()
            # Values per channel = batch size times every trailing
            # (non-channel) dimension.
            size_prods = size[0]
            for dim in size[2:]:
                size_prods *= dim
            if size_prods == 1:
                raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
        return torch.batch_norm(
            input, weight, bias, running_mean, running_var,
            training, momentum, eps,
            False,  # cudnn_enabled: deliberately bypass cuDNN's kernel
        )

    torch.nn.functional.batch_norm = batch_norm
    return original
But I am not sure how to apply this patch when my model builds its BatchNorm layers like this:
def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation, bn_running_avg=False):
    """Build a bias-free ``nn.Conv2d`` followed by an ``nn.BatchNorm2d``.

    Args:
        in_planes: Input channels of the convolution.
        out_planes: Output channels of the convolution. This is also the
            number of features normalised by the batch-norm layer.
        kernel_size: Forwarded to ``nn.Conv2d``.
        stride: Forwarded to ``nn.Conv2d``.
        pad: Padding used when ``dilation`` is not greater than 1.
        dilation: Forwarded to ``nn.Conv2d``. When it is greater than 1, it
            also replaces ``pad`` as the padding.
        bn_running_avg: Passed to ``nn.BatchNorm2d`` as
            ``track_running_stats``.

    Returns:
        ``nn.Sequential`` containing the convolution and then the batch norm.
    """
    # Dilated convolutions are padded by the dilation rate instead of `pad`.
    # Presumably this keeps the spatial size for 3x3 / stride-1 kernels;
    # confirm against callers.
    if dilation > 1:
        effective_padding = dilation
    else:
        effective_padding = pad

    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=effective_padding,
        dilation=dilation,
        bias=False,  # the following BatchNorm supplies the affine shift
    )
    # NOTE(review): with the default bn_running_avg=False, no running
    # statistics are kept. Per the BatchNorm2d docs, the layer then
    # normalises with batch statistics in eval mode as well.
    norm = nn.BatchNorm2d(out_planes, track_running_stats=bn_running_avg)
    return nn.Sequential(conv, norm)