I am interested in taking a model with pretrained weights, and replacing all 2d convolution layers by a version of those that incorporates weight standardization without losing the pretrained values in the layer. An implementation of such a
StdConv2d layer could be something as simple as:
class StdConv2d(nn.Conv2d): def forward(self, x): w = self.weight v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) w = (w - m) / torch.sqrt(v + 1e-10) return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
which I got from the BiT repo here.
I have been following this approach to replace the layers, which means that I am running:
import torch, torch.nn as nn, torchvision.models as models def get_layer(model, name): layer = model for attr in name.split("."): layer = getattr(layer, attr) return layer def set_layer(model, name, layer): try: attrs, name = name.rsplit(".", 1) model = get_layer(model, attrs) except ValueError: pass setattr(model, name, layer) class StdConv2d(nn.Conv2d): def forward(self, x): w = self.weight v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) w = (w - m) / torch.sqrt(v + 1e-10) return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
# let us replace conv2d layers in a resnet18; are pretrained weights kept? model = models.resnet18(pretrained=True) old_first_layer_weights_avg = model.conv1.weight.mean() for name, module in model.named_modules(): if isinstance(module, nn.Conv2d): # Get current bn layer conv = get_layer(model, name) # Create new in layer new_conv = StdConv2d(conv.in_channels, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, conv.dilation, conv.groups, conv.bias, conv.padding_mode) set_layer(model, name, new_conv) new_first_layer_weight_avg = model.conv1.weight.mean() print(old_first_layer_weights_avg == new_first_layer_weight_avg )
false, meaning that I am losing the pretrained weights here because
StdConv2d is randomly initialized.
A related issue is that the above works because
resnet18 does not have conv2d layers with a bias, but when there is a bias, I get an error. This can be seen by trying to replace the conv layers with the above code in e.g.
alexnet, which will throw the following exception:
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
It seems to me that there must be a simple way to add standardization to conv layers while keeping the pretrained weights intact, but my anti-expertise in OOP prevents me from finding it, any help is appreciated