Currently, the weight_norm function works only with nn.Module. Are there any plans to extend it to nn.Parameter as well?
weight_norm works by registering a forward_pre_hook on the Module. I don’t think there is currently any way to register a forward_pre_hook on a Parameter.
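For reference, this is roughly what the built-in mechanism looks like on a module (a quick sketch using torch.nn.utils.weight_norm; _forward_pre_hooks is a private attribute, inspected here only for illustration):

import torch.nn as nn
from torch.nn.utils import weight_norm

lin = weight_norm(nn.Linear(20, 40), name='weight', dim=0)
# 'weight' has been replaced by 'weight_g' and 'weight_v'; a
# forward_pre_hook recombines them before every forward() call
print(lin.weight_g.shape, lin.weight_v.shape)  # (40, 1), (40, 20)
print(lin._forward_pre_hooks)                  # holds the WeightNorm hook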
But you could do this…
Instead of declaring and then using a single parameter called myparam, you declare two parameters, myparam_g and myparam_v:

# myparam is the original tensor and dim is the dimension to keep
# (matching weight_norm's dim argument); _norm is defined below
model.myparam_g = Parameter(_norm(myparam, dim).data)
model.myparam_v = Parameter(myparam.data)
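As a minimal sketch, the registration could live in a module's __init__ (MyModule, w_init, and dim are hypothetical names; _norm is the helper defined just below):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self, w_init, dim=0):
        super().__init__()
        self.dim = dim
        # store the decomposed pair instead of the raw parameter
        self.myparam_g = nn.Parameter(_norm(w_init, dim).data)
        self.myparam_v = nn.Parameter(w_init.data)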
Then you write a function that recomputes myparam from those two pieces. For example, this generic helper will do:
def _norm(p, dim):
    """Computes the norm over all dimensions except dim."""
    if dim is None:
        return p.norm()
    elif dim == 0:
        output_size = (p.size(0),) + (1,) * (p.dim() - 1)
        return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
    elif dim == p.dim() - 1:
        output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
        return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)
    else:
        return _norm(p.transpose(0, dim), 0).transpose(0, dim)
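The kept dimension survives and every other dimension is reduced to size 1, so the result broadcasts against p:

import torch

p = torch.randn(4, 3, 5)
print(_norm(p, 0).shape)   # torch.Size([4, 1, 1])
print(_norm(p, 2).shape)   # torch.Size([1, 1, 5])
print(_norm(p, None))      # 0-dim tensor: norm over all elements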
def compute_weight(model, name, dim):
    """Recombines name_g and name_v into the original parameter."""
    g = getattr(model, name + '_g')
    v = getattr(model, name + '_v')
    return v * (g / _norm(v, dim))
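A quick sanity check (a sketch, assuming the definitions above): right after registration, recombining should reproduce the original tensor exactly, since g is initialized to _norm(v, dim):

import torch
import torch.nn as nn

w = torch.randn(4, 3)
model = nn.Module()
model.myparam_g = nn.Parameter(_norm(w, 0).data)
model.myparam_v = nn.Parameter(w.data)
print(torch.allclose(compute_weight(model, 'myparam', 0), w))  # True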
Then you can use it in the forward pass to calculate myparam:

def forward(self, x):
    myparam = compute_weight(self, 'myparam', self.dim)
    ...
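Putting it together (using the hypothetical MyModule sketch from above), gradients flow to both decomposed parameters, which is why myparam must be recomputed on every forward pass rather than cached:

import torch

m = MyModule(torch.randn(40, 20), dim=0)
x = torch.randn(8, 20)
w = compute_weight(m, 'myparam', m.dim)
loss = (x @ w.t()).pow(2).sum()
loss.backward()
print(m.myparam_g.grad.shape)  # torch.Size([40, 1])
print(m.myparam_v.grad.shape)  # torch.Size([40, 20])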
The above functions are adapted from the source code for weight_norm.