Change batch norm gammas init

Is there a way to change the initialization of the batchnorm gammas in my model that works for both the affine=True and affine=False settings? In PyTorch >= 1.2.0 they are initialized to one by default, but I would like to use a different value.

The issue is that with affine=False the batchnorm gammas (bn.weight) are None, and it's not clear whether it's safe to modify them.

You should be able to use torch.nn.init if you check for None before calling an init method on the affine parameters:

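import torch
import torch.nn as nn

def bn_weight_init(m):
    # Only touch BatchNorm layers; with affine=False their weight/bias are None
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) \
            and m.weight is not None and m.bias is not None:
        torch.nn.init.normal_(m.weight)
        torch.nn.init.normal_(m.bias)

bn = nn.BatchNorm2d(3, affine=True)
print(bn.weight, bn.bias)  # default: ones and zeros
bn.apply(bn_weight_init)
print(bn.weight, bn.bias)  # now drawn from N(0, 1)

bn = nn.BatchNorm2d(3, affine=False)
print(bn.weight, bn.bias)  # None None
bn.apply(bn_weight_init)   # no-op: there is nothing to initialize
print(bn.weight, bn.bias)  # still None None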

Thanks for your help :slight_smile: However, my goal is to change the gammas and betas even when they are not learnable (affine=False).

By default, affine=False means the gammas are effectively fixed to 1 (and the betas to 0) and are not parameters. I want to set them to some value other than 1 while keeping them non-learnable.

The cleanest option would be to write my own BN module, but I can't figure out where the gammas are set to 1 when affine=False, since they are None.

Thanks!

Ah OK, in that case I would use affine=True, set the desired values via bn.weight.fill_() or .copy_() (inside a torch.no_grad() block, since these are leaf tensors that require grad), and freeze the affine parameters by setting their requires_grad attribute to False.
Would that work for you?
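A minimal sketch of that suggestion, assuming a 3-channel BatchNorm2d; the target values 2.0 for the gammas and 0.5 for the betas are arbitrary examples:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, affine=True)

# Overwrite the default init (ones/zeros) in place; no_grad is needed because
# in-place ops on leaf tensors that require grad are otherwise rejected.
with torch.no_grad():
    bn.weight.fill_(2.0)                          # gammas
    bn.bias.copy_(torch.tensor([0.5, 0.5, 0.5]))  # betas

# Freeze the affine parameters so no optimizer step ever changes them.
bn.weight.requires_grad_(False)
bn.bias.requires_grad_(False)

print(bn.weight, bn.bias)
print([name for name, p in bn.named_parameters() if p.requires_grad])
```

Note that the frozen tensors are still registered as parameters, so they keep showing up in bn.parameters().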

Thanks. My issue is that I have a really annoying use case (meta-learning) where this hack doesn't work: all my weights must be concatenated into a single trainable tensor and passed to the network as an argument in the forward pass. Freezing some parameters would mean having non-trainable elements inside that weight vector.

In short, I need the gammas not to be parameters when affine=False; otherwise they break everything… Do you think tweaking the BN layer source code is required here? I can't figure out why gamma=None gets interpreted as gamma=1 in the BN function.

There are internal checks, as seen here, which replace the None values with the defaults (a weight of ones and a bias of zeros).
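The effect of those defaults can be checked through the functional interface, torch.nn.functional.batch_norm, which takes weight and bias as ordinary tensor arguments. A minimal sketch, assuming an arbitrary random 4D input and a constant 2.0 for the fixed gammas and betas: passing None should match passing ones and zeros, and a plain (non-Parameter) tensor can stand in for fixed gammas, which may suit a setup where the learnable weights are passed in at forward time.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)  # (N, C, H, W), arbitrary example input
C = x.shape[1]

# weight=None / bias=None: the no-affine path taken when affine=False.
out_none = F.batch_norm(x, None, None, weight=None, bias=None, training=True)

# Explicit gammas of one and betas of zero should give the same result.
out_ones = F.batch_norm(x, None, None, weight=torch.ones(C),
                        bias=torch.zeros(C), training=True)

# Fixed, non-learnable gamma/beta: plain tensors that belong to no parameter list.
gamma = torch.full((C,), 2.0)
beta = torch.full((C,), 2.0)
out_custom = F.batch_norm(x, None, None, weight=gamma, bias=beta, training=True)

print(torch.allclose(out_none, out_ones, atol=1e-6))
print(torch.allclose(out_custom, 2.0 * out_none + 2.0, atol=1e-5))
```

Running statistics are passed as None here only to keep the example short; a real layer would pass its running_mean/running_var buffers.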

Thanks for the information about the use case.
Unfortunately, I'm unsure what the best way would be (probably reimplementing the batchnorm layer manually).
However, this hacky trick might also work for you:

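import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, affine=False)
# With affine=False, 'weight' and 'bias' are registered as None parameters; remove them
delattr(bn, 'weight')
delattr(bn, 'bias')
# Re-register them as buffers: saved in the state_dict, but not returned by parameters()
bn.register_buffer('weight', torch.ones(3) * 2)
bn.register_buffer('bias', torch.ones(3) * 2)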

Note that this is completely untested code and might break in various ways, but you could try it and see whether it works.
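One way to check whether the hack behaves as intended is to compare the layer's output with a plain (no-affine) normalization that is scaled and shifted by hand, and to confirm that the constants ended up as buffers rather than parameters. A quick sketch under the same assumptions as the snippet above (3 channels, constant 2.0); the random input is arbitrary:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

bn = nn.BatchNorm2d(3, affine=False)
# Same hack as above: swap the None parameter slots for constant buffers.
delattr(bn, 'weight')
delattr(bn, 'bias')
bn.register_buffer('weight', torch.ones(3) * 2)
bn.register_buffer('bias', torch.ones(3) * 2)

x = torch.randn(8, 3, 4, 4)
out = bn(x)  # training mode: normalizes with the batch statistics

# Reference: normalization without affine, then the constant scale and shift.
ref = 2 * F.batch_norm(x, None, None, training=True, eps=bn.eps) + 2

print(torch.allclose(out, ref, atol=1e-5))
print(list(bn.parameters()))            # no learnable parameters left
print(sorted(dict(bn.named_buffers())))
```

Buffers are moved by .to() and stored in state_dict(), but .parameters() does not return them, which is what should keep them out of a flattened weight vector.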

Thanks a lot! I’ll have a go and give up if it doesn’t work, since changing C code sounds like it’s above my (PhD) paygrade… :stuck_out_tongue:

In the worst case, you can use my manual and slow reference implementation to hack around. :slight_smile:
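If you do end up reimplementing the layer, a minimal sketch might look like the following. It is not the reference implementation mentioned above: the class name FixedAffineBatchNorm2d and its gamma/beta arguments are hypothetical, and it only covers 4D (N, C, H, W) inputs with momentum-based running statistics (no cumulative averaging as with momentum=None).

```python
import torch
import torch.nn as nn


class FixedAffineBatchNorm2d(nn.Module):
    """Hypothetical BatchNorm2d variant whose gamma/beta are constant buffers,
    so they never appear in .parameters()."""

    def __init__(self, num_features, gamma=1.0, beta=0.0, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.register_buffer('gamma', torch.full((num_features,), float(gamma)))
        self.register_buffer('beta', torch.full((num_features,), float(beta)))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # Per-channel batch statistics over (N, H, W), biased variance.
            mean = x.mean(dim=(0, 2, 3))
            var = (x - mean[None, :, None, None]).pow(2).mean(dim=(0, 2, 3))
            n = x.numel() / x.size(1)  # elements per channel
            with torch.no_grad():
                # Like nn.BatchNorm2d, the running variance uses the unbiased estimate.
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(
                    self.momentum * var * n / (n - 1))
        else:
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        return self.gamma[None, :, None, None] * x_hat + self.beta[None, :, None, None]


# Usage: compare against nn.BatchNorm2d with its affine parameters set to the same constants.
torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)
ref_bn = nn.BatchNorm2d(3)
with torch.no_grad():
    ref_bn.weight.fill_(2.0)
    ref_bn.bias.fill_(2.0)
my_bn = FixedAffineBatchNorm2d(3, gamma=2.0, beta=2.0)
train_match = torch.allclose(my_bn(x), ref_bn(x), atol=1e-5)
my_bn.eval()
ref_bn.eval()
eval_match = torch.allclose(my_bn(x), ref_bn(x), atol=1e-5)
print(train_match, eval_match, list(my_bn.parameters()))
```

Keeping gamma and beta as buffers means the module has no parameters of its own, so whatever learnable weights the meta-learning setup passes in stay separate from these constants.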
