Replacing conv modules with custom convs, then NotImplementedError

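For anyone who just wants a working solution: the function below walks a model recursively and replaces every BatchNorm2d layer with a SynchronizedBatchNorm2d.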

def replace_bn(m, name, factory=None):
    """Recursively replace every ``torch.nn.BatchNorm2d`` submodule of *m* in place.

    Only modules whose type is exactly ``torch.nn.BatchNorm2d`` are replaced;
    subclasses are left untouched. Each replacement inherits the old layer's
    learned state (affine weight/bias and running statistics, via
    ``load_state_dict``), its train/eval mode, and its device. Without this,
    converting a pretrained network would silently reset those values.

    Args:
        m: The ``torch.nn.Module`` whose descendants are scanned. *m* itself
            is never replaced (it has no parent to assign into).
        name: Label for *m* used in the progress messages. Recursive calls
            extend it with the child's attribute name ("net.layer1.0").
        factory: Optional callable ``factory(old_bn) -> nn.Module`` that
            builds the replacement. Defaults to
            ``SynchronizedBatchNorm2d(num_features, eps, momentum, affine)``.
            That name must be importable in this module's globals when the
            function is called.
    """
    if factory is None:
        def factory(bn):
            # Resolved at call time so this module does not need the name at def time.
            return SynchronizedBatchNorm2d(bn.num_features, bn.eps, bn.momentum, bn.affine)

    # Snapshot the children: we reassign entries of m while looping.
    for child_name, child in list(m.named_children()):
        if type(child) is torch.nn.BatchNorm2d:
            print('replaced: ', name, child_name)
            new_bn = factory(child)
            # Put the new layer on the same device as the old one before copying state.
            tensors = [*child.parameters(), *child.buffers()]
            if tensors:
                new_bn = new_bn.to(tensors[0].device)
            # strict=False tolerates replacements lacking e.g. num_batches_tracked.
            new_bn.load_state_dict(child.state_dict(), strict=False)
            new_bn.train(child.training)
            setattr(m, child_name, new_bn)
        else:
            replace_bn(child, f"{name}.{child_name}", factory)
        
# Convert the whole model in place; `net` is assumed to be the model built
# earlier (it is not defined in this snippet).
replace_bn(net, name="net")
12 Likes