For anyone who just wants a working solution, here is a function that recursively replaces every `BatchNorm2d` in a model with `SynchronizedBatchNorm2d`:
def _convert_bn(bn, norm_cls):
    """Build a ``norm_cls`` layer with *bn*'s config, state, placement and mode."""
    # Same four positional hyper-parameters the original replacement used.
    new_bn = norm_cls(bn.num_features, bn.eps, bn.momentum, bn.affine)

    # Put the new layer where the old one lived (e.g. GPU / half precision);
    # Module.to() only casts floating-point tensors, so integer buffers stay as-is.
    ref = bn.weight if bn.weight is not None else bn.running_mean
    if ref is not None:
        new_bn = new_bn.to(device=ref.device, dtype=ref.dtype)

    # Carry over learned affine params and running statistics. no_grad is
    # required: in-place copy_ into a leaf Parameter that requires grad raises.
    state_names = ('weight', 'bias', 'running_mean', 'running_var',
                   'num_batches_tracked')
    with torch.no_grad():
        for attr in state_names:
            src = getattr(bn, attr, None)
            dst = getattr(new_bn, attr, None)
            if src is not None and dst is not None and src.shape == dst.shape:
                dst.copy_(src)

    # Preserve frozen/trainable status of the affine parameters.
    for attr in ('weight', 'bias'):
        src = getattr(bn, attr, None)
        dst = getattr(new_bn, attr, None)
        if src is not None and dst is not None:
            dst.requires_grad_(src.requires_grad)

    # A freshly constructed module is in training mode; keep the old mode.
    new_bn.train(bn.training)
    return new_bn


def replace_bn(m, name, norm_cls=None):
    """Recursively replace every ``torch.nn.BatchNorm2d`` under *m*, in place.

    Each submodule whose type is exactly ``torch.nn.BatchNorm2d`` is rebound on
    its parent (via ``setattr``) to ``norm_cls(num_features, eps, momentum,
    affine)``. The replacement inherits the old layer's affine parameters,
    running statistics, device/dtype, train/eval mode and ``requires_grad``
    flags. This makes it safe to call after loading pretrained weights or
    after moving the model to a GPU.

    Args:
        m: Module whose subtree is rewritten in place.
        name: Label for *m*; used only in the progress message.
        norm_cls: Replacement class, taking ``(num_features, eps, momentum,
            affine)`` positionally. ``None`` (default) means
            ``SynchronizedBatchNorm2d``, which must be in scope where this
            function is defined.

    Note:
        If *m* itself is a ``BatchNorm2d`` it is not replaced, because there is
        no parent to rebind it on.
    """
    if norm_cls is None:
        norm_cls = SynchronizedBatchNorm2d
    # Snapshot the children: setattr below rebinds entries of m's submodules.
    for child_name, child in list(m.named_children()):
        # Exact type match (not isinstance): a BatchNorm2d subclass may carry
        # custom behaviour that a plain replacement would silently drop.
        if type(child) is torch.nn.BatchNorm2d:
            print('replaced: ', name, child_name)
            setattr(m, child_name, _convert_bn(child, norm_cls))
        else:
            replace_bn(child, child_name, norm_cls)
# Rewrite every BatchNorm2d inside `net` in place; "net" is only a label for the log output.
replace_bn(m=net, name="net")