Why is this forward pre-hook not working as expected?

I implemented weight standardization via register_forward_pre_hook as follows:

import torch
from torch.nn.parameter import Parameter

class WeightStandardization(object):
    def __init__(self, weight_name, dim):
        if dim is None:
            dim = 0
        self.weight_name = weight_name
        self.dim = dim

    def compute_weight(self, module):
        # Standardize the raw parameter '<weight_name>_0': zero mean, unit std along self.dim
        w = getattr(module, self.weight_name + '_0')
        w_mean = w.mean(dim=self.dim, keepdim=True)
        w = w - w_mean
        std = w.std(dim=self.dim, keepdim=True) + 1e-5
        w = w / std
        return w

    @staticmethod
    def apply(module, weight_name, dim):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightStandardization) and hook.weight_name == weight_name:
                raise RuntimeError("Cannot register two weight_standardization hooks on the same parameter {}".format(weight_name))

        if dim is None:
            dim = 0

        fn = WeightStandardization(weight_name, dim)

        weight = getattr(module, weight_name)

        # add weight before WS as parameter
        module.register_parameter(weight_name + '_0', Parameter(weight.data))

        # remove w from parameter list
        delattr(module, weight_name)
        setattr(module, weight_name, fn.compute_weight(module))
        if weight_name in module._parameters:
            del module._parameters[weight_name]

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        weight = self.compute_weight(module)
        delattr(module, self.weight_name)
        del module._parameters[self.weight_name + '_0']
        setattr(module, self.weight_name, Parameter(weight.data))

    def __call__(self, module, inputs):
        setattr(module, self.weight_name, self.compute_weight(module))

def weight_standardization(module, weight_name='weight', dim=0):
    r"""Applies weight standardization to a parameter in the given module."""
    WeightStandardization.apply(module, weight_name, dim)
    return module

def remove_weight_standardization(module, weight_name='weight'):
    r"""Removes the weight standardization reparameterization from a module.

    Args:
        module (Module): containing module
        weight_name (str, optional): name of the weight parameter
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightStandardization) and hook.weight_name == weight_name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("weight_standardization of '{}' not found in {}"
                     .format(weight_name, module))

It’s directly modified from PyTorch’s weight_norm.py. I also implemented a weight-standardized version of Linear as follows:

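import torch.nn as nn
import torch.nn.functional as F

class Linear_WS(nn.Linear):
    def forward(self, input):
        # Standardize the weight on every call: zero mean, unit std along dim 0
        weight = self.weight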
        weight_mean = weight.mean(dim=0, keepdim=True)
        weight = weight - weight_mean
        std = weight.std(dim=0, keepdim=True) + 1e-5
        weight = weight / std
        return F.linear(input, weight, self.bias)

A model with Linear_WS trains fine, while a model with

l0 = nn.Linear(in_features=input_dim, out_features=output_dim)
l0 = weight_standardization(l0, weight_name='weight', dim=0)

doesn’t converge at all, even though the two are expected to behave the same way.
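One way to test the claim that the two variants should behave the same is to compare their outputs and gradients directly. The following is a minimal sketch, not the poster’s code: it does the reparameterization by hand (moving `weight` to a trainable `weight_0` and rebuilding `weight` in a forward pre-hook) instead of going through the WeightStandardization class, and `standardize`, `LinearWS` and `hook_ws` are hypothetical names that mirror the arithmetic of compute_weight() and Linear_WS.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def standardize(w, dim=0, eps=1e-5):
    # Mean-subtract, then divide by (std + eps) along `dim`, as in the post
    w = w - w.mean(dim=dim, keepdim=True)
    return w / (w.std(dim=dim, keepdim=True) + eps)


class LinearWS(nn.Linear):
    # Standardizes inside forward(), like Linear_WS above
    def forward(self, input):
        return F.linear(input, standardize(self.weight), self.bias)


def hook_ws(module, inputs):
    # Forward pre-hook: rebuild the plain tensor 'weight' from the Parameter 'weight_0'
    module.weight = standardize(module.weight_0)


torch.manual_seed(0)
a = LinearWS(5, 3)
b = nn.Linear(5, 3)
b.load_state_dict(a.state_dict())  # same starting weight and bias

# Reparameterize b by hand: 'weight_0' is trained, 'weight' is derived from it
b.register_parameter('weight_0', nn.Parameter(b._parameters.pop('weight').data))
b.register_forward_pre_hook(hook_ws)

x = torch.randn(16, 5)
target = torch.randn(16, 3)
F.mse_loss(a(x), target).backward()
F.mse_loss(b(x), target).backward()

# Outputs and gradients with respect to the raw weight should agree
print(torch.allclose(a(x), b(x)))
print(torch.allclose(a.weight.grad, b.weight_0.grad))
```

If both checks print True, the two formulations agree on a single step, so a training difference would have to come from elsewhere in the setup. One thing worth ruling out (an assumption, not something stated in the thread) is an optimizer constructed before the hook was applied: it would still hold the old `weight` Parameter rather than `weight_0`, so the weight would never be updated.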

I can’t figure out why the forward pre-hook isn’t working as expected. Am I misunderstanding how forward pre-hooks are meant to be used here?

FYI, the PyTorch version is 1.6.0.
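On the question of how forward pre-hooks are meant to be used: a pre-hook only fires if it has been registered with `register_forward_pre_hook`, it is called as `hook(module, inputs)` before `forward()`, and a forward hook registered with `register_forward_hook` runs afterwards as `hook(module, inputs, output)`. A small sketch of that ordering (the hook names and the `calls` list are hypothetical):

```python
import torch
import torch.nn as nn

calls = []


def pre_hook(module, inputs):
    # Runs before module.forward(); `inputs` is the tuple of positional args
    calls.append(('pre', inputs[0].shape))
    # Returning None leaves the inputs unchanged


def post_hook(module, inputs, output):
    # Runs after module.forward() with the computed output
    calls.append(('post', output.shape))


lin = nn.Linear(4, 2)
h1 = lin.register_forward_pre_hook(pre_hook)
h2 = lin.register_forward_hook(post_hook)
lin(torch.randn(3, 4))
print(calls)  # [('pre', torch.Size([3, 4])), ('post', torch.Size([3, 2]))]

# Hooks only fire while registered; the returned handles remove them
h1.remove()
h2.remove()
lin(torch.randn(3, 4))
print(len(calls))  # still 2
```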

Could you use a forward hook to verify that the weights were indeed standardized?
The forward hook is executed during the forward pass after the pre-hook, so you should be able to check the weight statistics there and make sure the workflow is correct.
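A minimal sketch of such a check, assuming the same setup as the original post (a raw Parameter `weight_0`, standardization along dim 0): the forward hook reads the `weight` that forward() actually used and reports how far its per-column mean and std are from 0 and 1. `ws_pre_hook`, `check_ws_stats` and the `last_ws_errors` attribute are hypothetical names; the pre-hook is a stand-in for WeightStandardization.__call__. The std will not be exactly 1 because of the +1e-5 in the denominator.

```python
import torch
import torch.nn as nn


def ws_pre_hook(module, inputs):
    # Stand-in for the WeightStandardization pre-hook: standardize over dim 0
    w = module.weight_0 - module.weight_0.mean(dim=0, keepdim=True)
    module.weight = w / (w.std(dim=0, keepdim=True) + 1e-5)


def check_ws_stats(module, inputs, output):
    # Forward hook: runs after forward(), so module.weight is what forward() used
    w = module.weight.detach()
    mean_err = w.mean(dim=0).abs().max().item()
    std_err = (w.std(dim=0) - 1).abs().max().item()
    print(f'max |mean| = {mean_err:.1e}, max |std - 1| = {std_err:.1e}')
    module.last_ws_errors = (mean_err, std_err)  # hypothetical attribute for inspection


torch.manual_seed(0)
lin = nn.Linear(8, 16)
# Move the raw weight to 'weight_0'; the pre-hook recreates 'weight' each forward
lin.register_parameter('weight_0', nn.Parameter(lin._parameters.pop('weight').data))
lin.register_forward_pre_hook(ws_pre_hook)
lin.register_forward_hook(check_ws_stats)
lin(torch.randn(2, 8))
```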

Just did it. The weight is indeed standardized. The forward hook I used for verification is:

import torch

def verify_WS(module, input, output):
    # Recompute the standardized weight from the raw parameter 'weight_0'
    w = getattr(module, 'weight_0')
    w_mean = w.mean(dim=0, keepdim=True)
    w = w - w_mean
    std = w.std(dim=0, keepdim=True) + 1e-5
    w = w / std

    # The weight actually used in this forward pass
    w2 = getattr(module, 'weight')

    mask = w != w2
    if mask.any():
        d = torch.abs(w - w2)
        print('d=', d)
    else:
        print('WS verification passed')

So the forward pass seems OK.

Any more suggestions?

I’m not sure what to suggest, as it seems your code works properly based on your last post, no?

I suspect the backward pass might be broken somehow.

The only difference between my weight standardization hook and PyTorch 1.6.0’s weight normalization is in the compute_weight() function: my implementation does the computation explicitly in Python, whereas PyTorch 1.6.0’s implementation wraps the actual computation in a C-level _weight_norm() function.
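Doing the computation in Python rather than in a compiled op should not by itself break the backward pass, since autograd differentiates any composition of differentiable tensor ops. A quick way to check that, sketched here with PyTorch’s own `nn.utils.weight_norm` as the reference (deprecated in recent releases in favor of `torch.nn.utils.parametrizations.weight_norm`, but still available), is to write the same reparameterization w = g * v / ||v|| with ordinary tensor ops and compare gradients. `explicit_weight` is a hypothetical helper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(6, 5)
target = torch.randn(6, 3)

# Reference: PyTorch's weight_norm (internally uses the _weight_norm op)
ref = nn.utils.weight_norm(nn.Linear(5, 3), name='weight', dim=0)

# Independent leaf copies of the same g, v and bias
g = ref.weight_g.detach().clone().requires_grad_(True)
v = ref.weight_v.detach().clone().requires_grad_(True)
bias = ref.bias.detach().clone().requires_grad_(True)


def explicit_weight(g, v):
    # w = g * v / ||v||, with the norm taken over every dim except dim 0
    return g * v / v.norm(dim=1, keepdim=True)


F.mse_loss(ref(x), target).backward()
F.mse_loss(F.linear(x, explicit_weight(g, v), bias), target).backward()

# Gradients should agree up to floating-point rounding
print(torch.allclose(ref.weight_g.grad, g.grad, atol=1e-5))
print(torch.allclose(ref.weight_v.grad, v.grad, atol=1e-5))
```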

If you suspect the backward pass is broken, did you check the gradients and compare them to the theoretical gradients, or did you use gradcheck?
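A minimal gradcheck sketch for the standardization itself, assuming the same arithmetic as compute_weight() and Linear_WS (the names `standardize` and `ws_linear` are hypothetical). gradcheck compares the analytical Jacobian from autograd with a finite-difference estimate, so the inputs should be float64.

```python
import torch
import torch.nn.functional as F
from torch.autograd import gradcheck


def standardize(w, dim=0, eps=1e-5):
    # Same arithmetic as compute_weight() / Linear_WS in the thread
    w = w - w.mean(dim=dim, keepdim=True)
    return w / (w.std(dim=dim, keepdim=True) + eps)


def ws_linear(x, w0, b):
    # Linear layer whose weight is standardized from the raw weight w0
    return F.linear(x, standardize(w0), b)


torch.manual_seed(0)
# Double precision keeps finite-difference error below gradcheck's tolerances
x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
w0 = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(4, dtype=torch.double, requires_grad=True)

# Returns True on success and raises an error on a mismatch (default behavior)
print(gradcheck(standardize, (w0,)))
print(gradcheck(ws_linear, (x, w0, b)))
```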