Replacing nn.Conv2d with standardized convolutions but keeping weight values

Hello,

I would like to take a model with pretrained weights and replace all of its 2D convolution layers with a version that applies weight standardization, without losing the pretrained values stored in each layer. An implementation of such a StdConv2d layer could be as simple as:

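import torch
import torch.nn as nn
import torch.nn.functional as F

class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        # standardize each output filter to zero mean and unit variance
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)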
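To make the standardization step concrete: for every output channel the layer computes (w - mean) / sqrt(var + eps) over the input-channel and kernel dimensions, so each filter ends up with roughly zero mean and unit variance. The sketch below isolates that computation in a hypothetical helper, standardize_weight, and checks the per-filter statistics on a random weight tensor.

```python
import torch

def standardize_weight(w: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    # w has shape (out_channels, in_channels // groups, kH, kW);
    # statistics are taken per output channel over dims 1, 2 and 3
    v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
    return (w - m) / torch.sqrt(v + eps)

torch.manual_seed(0)
w = torch.randn(8, 3, 3, 3) * 0.5 + 0.2   # arbitrary, non-standardized filters
ws = standardize_weight(w)
print(ws.mean(dim=[1, 2, 3]))                  # close to 0 for each filter
print(ws.var(dim=[1, 2, 3], unbiased=False))   # close to 1 for each filter
```

Note that only the weights used in the forward pass are standardized; the stored self.weight parameter keeps its original values, which is why copying pretrained weights into a StdConv2d is meaningful.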

which I took from the BiT (Big Transfer) repo.

I have been following this approach to replace the layers, which means that I am running:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

def get_layer(model, name):
    layer = model
    for attr in name.split("."):
        layer = getattr(layer, attr)
    return layer

def set_layer(model, name, layer):
    try:
        attrs, name = name.rsplit(".", 1)
        model = get_layer(model, attrs)
    except ValueError:
        pass
    setattr(model, name, layer)
    
class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
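As a possible alternative to the get_layer/set_layer helpers, nn.Module.get_submodule (available in PyTorch 1.9 or later, an assumption about your version) can resolve the parent module of a dotted name such as "layer1.0.conv1". The sketch below is a hypothetical drop-in set_layer built on it, exercised on a small nested nn.Sequential.

```python
import torch.nn as nn

def set_layer(model: nn.Module, name: str, layer: nn.Module) -> None:
    # "layer1.0.conv1" -> parent path "layer1.0", attribute "conv1";
    # a top-level name such as "conv1" has an empty parent path (the model itself)
    parent_name, _, attr = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, attr, layer)

model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),
    nn.Conv2d(8, 4, 1),
)
set_layer(model, "0.0", nn.Conv2d(3, 8, 5))   # replace a nested layer
print(model.get_submodule("0.0").kernel_size)  # (5, 5)
```

rpartition avoids the try/except around rsplit, because a name without a dot simply yields an empty parent path.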

and then:

# let us replace conv2d layers in a resnet18; are pretrained weights kept?
model = models.resnet18(pretrained=True)
old_first_layer_weights_avg = model.conv1.weight.mean()

for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        # Get the current conv layer
        conv = get_layer(model, name)
        # Create a new standardized conv layer with the same configuration
        new_conv = StdConv2d(conv.in_channels, conv.out_channels,
                             conv.kernel_size, conv.stride, conv.padding,
                             conv.dilation, conv.groups, conv.bias,
                             conv.padding_mode)
        set_layer(model, name, new_conv)

new_first_layer_weight_avg = model.conv1.weight.mean()
print(old_first_layer_weights_avg == new_first_layer_weight_avg)

This prints tensor(False), meaning that I am losing the pretrained weights because each StdConv2d is freshly (randomly) initialized.

A related issue: the code above only runs because resnet18 has no Conv2d layers with a bias. When a layer does have a bias, I get an error. This can be seen by trying to replace the conv layers with the above code in e.g. alexnet, which throws the following exception:

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

It seems to me that there must be a simple way to add standardization to conv layers while keeping the pretrained weights intact, but my anti-expertise in OOP prevents me from finding it. Any help is appreciated!

Thanks!

Adrian

I don’t see any code in your snippet that assigns or copies the .weight parameter into the new conv layer, so I guess this part is missing.
Something like:

with torch.no_grad():
    new_conv.weight.copy_(conv.weight)
    if conv.bias is not None:
        new_conv.bias.copy_(conv.bias)

set_layer(model, name, new_conv)

might work.
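Since StdConv2d subclasses nn.Conv2d without adding parameters, its state_dict keys ("weight" and, if present, "bias") should match the original layer's, so load_state_dict is another way to transfer the pretrained values. A minimal sketch, assuming both layers were built with the same shapes and bias setting:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

conv = nn.Conv2d(3, 8, 3, padding=1, bias=True)        # stands in for a pretrained layer
new_conv = StdConv2d(3, 8, 3, padding=1, bias=True)

# Same parameter names, so the values can be loaded directly (no manual no_grad needed)
new_conv.load_state_dict(conv.state_dict())
print(torch.equal(new_conv.weight, conv.weight), torch.equal(new_conv.bias, conv.bias))
```

load_state_dict also raises an error on missing or unexpected keys, e.g. if the bias flags of the two layers disagree, which can catch construction mistakes early.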

Hi!

Thanks for your quick help, I appreciate it a lot. Your solution works 99%; the only catch is that bias is overloaded in nn.Conv2d in a slightly odd way. In the constructor it must be a bool, but once the object is initialized, the bias attribute is either None or a float tensor (an nn.Parameter).
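A small sketch of that dual meaning of bias: the attribute is an nn.Parameter or None, while the constructor argument is evaluated as a bool, so passing a multi-element tensor there reproduces the "Boolean value of Tensor with more than one value is ambiguous" error. The variable names are illustrative only.

```python
import torch.nn as nn

with_bias = nn.Conv2d(3, 8, 3, bias=True)
without_bias = nn.Conv2d(3, 8, 3, bias=False)
print(isinstance(with_bias.bias, nn.Parameter))  # True
print(without_bias.bias)                         # None

# Derive the constructor flag from the attribute instead of passing the tensor itself
has_bias = with_bias.bias is not None

raised = False
try:
    nn.Conv2d(3, 8, 3, bias=with_bias.bias)  # a tensor where a bool is expected
except RuntimeError as err:
    raised = True
    print(err)
```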

For that reason, if one tries to construct new_conv using conv.bias when conv originally had a bias, the code will crash because it expects a bool. I fixed that as follows:

for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        # Get the current conv layer
        conv = get_layer(model, name)
        # conv.bias is a tensor or None, but the constructor needs a bool
        b = conv.bias is not None
        new_conv = StdConv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                             conv.stride, conv.padding, conv.dilation, conv.groups,
                             bias=b, padding_mode=conv.padding_mode)

        # Copy the pretrained values into the new layer
        with torch.no_grad():
            new_conv.weight.copy_(conv.weight)
            if b:
                new_conv.bias.copy_(conv.bias)
        set_layer(model, name, new_conv)
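Putting the pieces together, here is a self-contained sketch of the whole conversion as one hypothetical function, convert_to_std_conv. It assumes PyTorch 1.9 or later (for get_submodule), collects the target layers before modifying the model, matches the original layer's device and dtype, and uses exact type checks so already-converted StdConv2d layers are not converted again. A tiny nn.Sequential with biased convs stands in for a pretrained network such as alexnet, so nothing needs to be downloaded.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StdConv2d(nn.Conv2d):
    # Like the original, this calls F.conv2d directly, so it assumes padding_mode="zeros"
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

def convert_to_std_conv(model: nn.Module) -> nn.Module:
    # Collect targets first so the module tree is not changed while it is being iterated
    targets = [(name, m) for name, m in model.named_modules() if type(m) is nn.Conv2d]
    for name, conv in targets:
        new_conv = StdConv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                             stride=conv.stride, padding=conv.padding,
                             dilation=conv.dilation, groups=conv.groups,
                             bias=conv.bias is not None,   # a bool, not the tensor
                             padding_mode=conv.padding_mode)
        # Keep the original device and dtype (e.g. if the model is already on a GPU)
        new_conv.to(device=conv.weight.device, dtype=conv.weight.dtype)
        # Same parameter names, so the pretrained values load directly
        new_conv.load_state_dict(conv.state_dict())
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, attr, new_conv)
    return model

# Stand-in for a pretrained model with biased convs
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(8, 4, 1, bias=False))
old_params = [p.detach().clone() for p in model.parameters()]
convert_to_std_conv(model)
out = model(torch.randn(1, 3, 8, 8))
```

With torchvision the same call could be applied to e.g. models.alexnet(pretrained=True); comparing parameters with torch.equal, as the test does, is a stricter check than comparing means.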

Before closing, may I ask a follow-up question? I think this one is a bit more complex; if there is no easy way to achieve it, please forget it, it’s not so important (I hope).

In this repo, which provides another implementation of weight standardization, the bottom of the README recommends that one should “only replace convolutional layers that are followed by normalization layers such as BN, GN, etc.” Is there an easy modification of the above code that replaces a conv layer only when it is followed by a normalization layer?

Thanks again!!

This use case is not straightforward with the simple approach of iterating named_modules(), since the order in which the conv and normalization layers are used in the forward pass can differ from the order in which they are initialized.
E.g. you could first initialize all norm layers, then all conv layers, and then use them in a different order in forward. A better approach might be to use torch.fx and replace modules based on the actual traced graph. The torch.fx docs give some examples of graph manipulation.
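A minimal sketch of the torch.fx idea, under several assumptions: the model must be traceable with torch.fx.symbolic_trace (no data-dependent control flow), only normalization modules are detected (functional calls such as F.batch_norm are not), and a conv is only selected when a norm layer is its single consumer, so a conv whose output also feeds e.g. a residual add is skipped. NORM_TYPES, convs_followed_by_norm, replace_convs and the toy Net are hypothetical names chosen for illustration. The qualified names found in the traced graph are then used to swap layers in the original model.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F

class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

# Normalization module types to look for (an assumption; extend as needed)
NORM_TYPES = (nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d, nn.LayerNorm)

def convs_followed_by_norm(model: nn.Module) -> list:
    """Qualified names of Conv2d modules whose only consumer in the traced graph is a norm module."""
    traced = torch.fx.symbolic_trace(model)
    modules = dict(traced.named_modules())
    names = []
    for node in traced.graph.nodes:
        if node.op != "call_module" or not isinstance(modules[node.target], nn.Conv2d):
            continue
        users = list(node.users)
        if (len(users) == 1 and users[0].op == "call_module"
                and isinstance(modules[users[0].target], NORM_TYPES)):
            names.append(node.target)
    return list(dict.fromkeys(names))  # drop duplicates if a conv is called twice

def replace_convs(model: nn.Module, names: list) -> nn.Module:
    for name in names:
        conv = model.get_submodule(name)
        new_conv = StdConv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                             stride=conv.stride, padding=conv.padding,
                             dilation=conv.dilation, groups=conv.groups,
                             bias=conv.bias is not None, padding_mode=conv.padding_mode)
        new_conv.load_state_dict(conv.state_dict())  # keep the pretrained values
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, attr, new_conv)
    return model

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(8)  # defined first...
        self.conv_a = nn.Conv2d(3, 8, 3, padding=1, bias=False)
        self.conv_b = nn.Conv2d(8, 4, 1)

    def forward(self, x):
        x = self.bn(self.conv_a(x))  # ...but used after conv_a
        return self.conv_b(torch.relu(x))

model = Net()
names = convs_followed_by_norm(model)
replace_convs(model, names)
```

Here only conv_a is converted, even though bn is registered before it, because the decision is based on how the layers are connected in forward rather than on their registration order.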