Weight normalization for nn.Parameter

Currently, the weight_norm function works only with nn.Module. Are there any plans to extend it to nn.Parameter as well?

weight_norm works by registering a forward_pre_hook on the Module.
I don’t think there is currently any way to register a forward_pre_hook on a Parameter.

But you could do this…
Instead of declaring and then using a single parameter called myparam, you declare two parameters, myparam_g and myparam_v:

model.myparam_g = Parameter(_norm(myparam, dim).data)  # magnitude: the norm of the original weight
model.myparam_v = Parameter(myparam.data)               # direction: the original weight values

Then you write a function that recomputes myparam from those two parameters. For example, these generic helpers will do:

def _norm(p, dim):
    """Computes the norm over all dimensions except dim"""
    if dim is None:
        return p.norm()
    elif dim == 0:
        output_size = (p.size(0),) + (1,) * (p.dim() - 1)
        return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
    elif dim == p.dim() - 1:
        output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
        return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)
    else:
        return _norm(p.transpose(0, dim), 0).transpose(0, dim)
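To make the shapes concrete, here is a quick check (the 10x5 shape is just an example):

import torch

w = torch.randn(10, 5)
print(_norm(w, dim=0).shape)     # torch.Size([10, 1]): one norm per row
print(_norm(w, dim=1).shape)     # torch.Size([1, 5]): one norm per column
print(_norm(w, dim=None).shape)  # torch.Size([]): a single scalar norm over the whole tensor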

def compute_weight(model, name, dim):
    """Recomputes the weight from its magnitude (g) and direction (v) parameters."""
    g = getattr(model, name + '_g')
    v = getattr(model, name + '_v')
    return v * (g / _norm(v, dim))
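As a quick sanity check, compute_weight reproduces the original tensor right after the split, because g starts out equal to the norm of v. A sketch, using a bare nn.Module just as a container for the two parameters:

import torch
from torch import nn
from torch.nn import Parameter

original = torch.randn(10, 5)
container = nn.Module()  # any module works as a home for the two parameters
container.myparam_g = Parameter(_norm(original, 0).data)
container.myparam_v = Parameter(original.data)
print(torch.allclose(compute_weight(container, 'myparam', dim=0), original))  # True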

Then you can use it in the forward pass to calculate myparam

def forward(self, ...):
    myparam = compute_weight(self, 'myparam', dim)  # use the same dim as when the parameter was split
    ...
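Putting it all together, here is a minimal self-contained sketch of a linear layer that uses this manual reparameterization. The class name NormedLinear, the shapes, and dim=0 are placeholders, and it assumes the _norm and compute_weight helpers above:

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter

class NormedLinear(nn.Module):
    def __init__(self, in_features, out_features, dim=0):
        super().__init__()
        self.dim = dim
        weight = torch.randn(out_features, in_features) * 0.1
        self.weight_g = Parameter(_norm(weight, dim).data)  # magnitude
        self.weight_v = Parameter(weight.data)              # direction
        self.bias = Parameter(torch.zeros(out_features))

    def forward(self, x):
        weight = compute_weight(self, 'weight', self.dim)  # recomputed on every call
        return F.linear(x, weight, self.bias)

layer = NormedLinear(5, 10)
print(layer(torch.randn(3, 5)).shape)  # torch.Size([3, 10])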

The above functions are adapted from the source code for weight_norm.
