Hello,
I am interested in taking a model with pretrained weights and replacing all of its 2d convolution layers with a version that incorporates weight standardization, without losing the pretrained values in each layer. An implementation of such a StdConv2d layer could be something as simple as:
class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
which I got from the BiT repo here.
I have been following this approach to replace the layers, which means that I am running:
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed by StdConv2d.forward
import torchvision.models as models

def get_layer(model, name):
    # Walk a dotted name such as "layer1.0.conv1" down to the submodule
    layer = model
    for attr in name.split("."):
        layer = getattr(layer, attr)
    return layer

def set_layer(model, name, layer):
    # Replace the submodule at the dotted name with the given layer
    try:
        attrs, name = name.rsplit(".", 1)
        model = get_layer(model, attrs)
    except ValueError:
        # No dot in the name: the layer is a direct child of the model
        pass
    setattr(model, name, layer)

class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
and then:
# let us replace conv2d layers in a resnet18; are pretrained weights kept?
model = models.resnet18(pretrained=True)
old_first_layer_weights_avg = model.conv1.weight.mean()
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        # Get current conv layer
        conv = get_layer(model, name)
        # Create new StdConv2d layer
        new_conv = StdConv2d(conv.in_channels, conv.out_channels,
                             conv.kernel_size, conv.stride, conv.padding,
                             conv.dilation, conv.groups, conv.bias,
                             conv.padding_mode)
        set_layer(model, name, new_conv)
new_first_layer_weight_avg = model.conv1.weight.mean()
print(old_first_layer_weights_avg == new_first_layer_weight_avg)
This gives False, meaning that I am losing the pretrained weights here, because each new StdConv2d is randomly initialized.
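To illustrate the point for a single layer, here is a minimal sketch. It assumes a small bias-free nn.Conv2d as a stand-in for a pretrained layer (so nothing has to be downloaded), and it assumes that copying parameters with load_state_dict is acceptable, which should work because StdConv2d adds no parameters of its own. The variable names fresh_equal and copied_equal are only illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# Stand-in for a pretrained layer (bias-free, like the convs in resnet18).
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)

# Constructing a new module always draws fresh random parameters.
new_conv = StdConv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                     conv.stride, conv.padding, conv.dilation, conv.groups,
                     bias=False, padding_mode=conv.padding_mode)
fresh_equal = torch.equal(new_conv.weight, conv.weight)

# Copy the stored parameter values from the old layer into the new one.
new_conv.load_state_dict(conv.state_dict())
copied_equal = torch.equal(new_conv.weight, conv.weight)

print(fresh_equal, copied_equal)
```

Note that this only preserves the stored values. The forward pass now standardizes them, so the layer's outputs will generally differ from those of the original conv.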
A related issue is that the above only runs because resnet18 does not have conv2d layers with a bias; when there is a bias, I get an error. This can be seen by trying to replace the conv layers with the above code in e.g. alexnet, which will throw the following exception:
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
It seems to me that there must be a simple way to add standardization to conv layers while keeping the pretrained weights intact, but my anti-expertise in OOP prevents me from finding it. Any help is appreciated.
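A sketch of one possible approach, not a definitive solution: the bias argument of nn.Conv2d is a boolean flag, so passing the bias tensor itself (conv.bias) makes the constructor evaluate a multi-element tensor as a boolean, which would explain the exception above. The sketch passes `conv.bias is not None` instead and copies the parameters over afterwards. The helper name convert_to_std_conv and the small nn.Sequential test model are hypothetical, chosen only so the example runs without downloading pretrained weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-10)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


def convert_to_std_conv(module):
    """Recursively swap nn.Conv2d children for StdConv2d, keeping parameters.

    Hypothetical helper: the name and structure are illustrative only.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d) and not isinstance(child, StdConv2d):
            new_conv = StdConv2d(
                child.in_channels, child.out_channels, child.kernel_size,
                child.stride, child.padding, child.dilation, child.groups,
                bias=child.bias is not None,  # a bool flag, not the tensor
                padding_mode=child.padding_mode,
            )
            # Match the old layer's device and dtype, then copy its values.
            new_conv.to(device=child.weight.device, dtype=child.weight.dtype)
            new_conv.load_state_dict(child.state_dict())
            setattr(module, name, new_conv)
        else:
            convert_to_std_conv(child)
    return module


# Toy model with one biased and one bias-free conv (stand-in for a real net).
model = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1, bias=True),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(4, 2, 1, bias=False)),
)
old_w = model[0].weight.detach().clone()
old_b = model[0].bias.detach().clone()

convert_to_std_conv(model)
print(type(model[0]).__name__, torch.equal(model[0].weight, old_w))
```

The same call should apply to a torchvision model, e.g. convert_to_std_conv(models.alexnet(pretrained=True)). Recursing over named_children and replacing each child via setattr on its direct parent avoids the dotted-name lookup. The stored weights are kept, but since the forward pass now standardizes them, the network's outputs will generally change and some fine-tuning is likely needed.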
Thanks!
Adrian