I implemented weight standardization via register_forward_pre_hook
as follows:
from torch.nn.parameter import Parameter


class WeightStandardization(object):
    def __init__(self, weight_name, dim):
        if dim is None:
            dim = 0
        self.weight_name = weight_name
        self.dim = dim

    def compute_weight(self, module):
        # standardize the stored raw weight: zero mean, (roughly) unit std along `dim`
        w = getattr(module, self.weight_name + '_0')
        w_mean = w.mean(dim=self.dim, keepdim=True)
        w = w - w_mean
        std = w.std(dim=self.dim, keepdim=True) + 1e-5
        w = w / std
        return w

    @staticmethod
    def apply(module, weight_name, dim):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightStandardization) and hook.weight_name == weight_name:
                raise RuntimeError("Cannot register two weight_standardization hooks "
                                   "on the same parameter {}".format(weight_name))
        if dim is None:
            dim = 0
        fn = WeightStandardization(weight_name, dim)
        weight = getattr(module, weight_name)
        # register the raw (pre-standardization) weight as a parameter
        module.register_parameter(weight_name + '_0', Parameter(weight.data))
        # remove the original weight from the parameter list
        delattr(module, weight_name)
        setattr(module, weight_name, fn.compute_weight(module))
        if weight_name in module._parameters:
            del module._parameters[weight_name]
        # recompute the weight before every forward()
        module.register_forward_pre_hook(fn)
        return fn

    def remove(self, module):
        weight = self.compute_weight(module)
        delattr(module, self.weight_name)
        del module._parameters[self.weight_name + '_0']
        setattr(module, self.weight_name, Parameter(weight.data))

    def __call__(self, module, inputs):
        setattr(module, self.weight_name, self.compute_weight(module))


def weight_standardization(module, weight_name='weight', dim=0):
    """Applies weight standardization to a parameter in the given module."""
    WeightStandardization.apply(module, weight_name, dim)
    return module


def remove_weight_standardization(module, weight_name='weight'):
    r"""Removes the weight standardization reparameterization from a module.

    Args:
        module (Module): containing module
        weight_name (str, optional): name of the weight parameter
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightStandardization) and hook.weight_name == weight_name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module
    raise ValueError("weight_standardization of '{}' not found in {}"
                     .format(weight_name, module))
It’s directly modified from PyTorch’s weight_norm.py.
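Applying it to a single layer, the reparameterization itself looks right to me: the raw weight is re-registered as weight_0 and weight becomes a plain tensor that the hook recomputes. This is the quick check I ran (the dimensions are arbitrary):

import torch
import torch.nn as nn

lin = nn.Linear(4, 3)
lin = weight_standardization(lin, weight_name='weight', dim=0)

# 'weight' is gone from the parameter list; only 'weight_0' and 'bias' remain learnable
print(sorted(name for name, _ in lin.named_parameters()))   # ['bias', 'weight_0']
print(isinstance(lin.weight, nn.Parameter))                  # False: it is a recomputed tensor

out = lin(torch.randn(2, 4))   # the pre-hook recomputes lin.weight before forward()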
I also implemented a weight-standardized version of Linear:
import torch.nn as nn
import torch.nn.functional as F


class Linear_WS(nn.Linear):
    def forward(self, input):
        # standardize the weight on the fly, same formula as compute_weight above
        weight = self.weight
        weight_mean = weight.mean(dim=0, keepdim=True)
        weight = weight - weight_mean
        std = weight.std(dim=0, keepdim=True) + 1e-5
        weight = weight / std
        return F.linear(input, weight, self.bias)
A model with Linear_WS trains fine, while the same model with
l0 = nn.Linear(in_features=input_dim, out_features=output_dim)
l0 = weight_standardization(l0, weight_name='weight', dim=0)
doesn’t converge at all, even though the two are expected to behave the same way.
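To make the comparison concrete, this is the kind of check I would expect to pass, since both versions standardize the weight over dim=0 (a minimal sketch; the shapes and tolerance are arbitrary):

import torch
import torch.nn as nn

ws_layer = Linear_WS(8, 4)

plain = nn.Linear(8, 4)
with torch.no_grad():
    # start both layers from the same raw weight and bias
    plain.weight.copy_(ws_layer.weight)
    plain.bias.copy_(ws_layer.bias)
hooked = weight_standardization(plain, weight_name='weight', dim=0)

x = torch.randn(2, 8)
print(torch.allclose(ws_layer(x), hooked(x), atol=1e-6))   # I expect True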
I can’t figure out why the forward pre hook version isn’t working as expected. Am I misunderstanding how forward pre hooks are supposed to be used here?
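My understanding is that a hook registered with register_forward_pre_hook runs right before every forward() call and is free to mutate module attributes, e.g. (a toy example, unrelated to the layers above):

import torch
import torch.nn as nn

layer = nn.Linear(3, 3)

def mark(module, inputs):
    # called before every forward(); may modify module attributes in place
    module.was_called = True

layer.register_forward_pre_hook(mark)
layer(torch.randn(1, 3))
print(layer.was_called)   # True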
F.Y.I., the PyTorch version is 1.6.0.