I’m trying to add an orthogonality constraint to the weight of a Linear layer. This looks similar to what torch.nn.utils.weight_norm does, so I wrote a similar class, WeightOrtho, to orthogonalize the weight: each time the forward pre-hook runs, it assigns a freshly orthogonalized weight to the layer. But this does not seem to work: the weight never updates during training (a minimal run reproducing this is at the end of the post). Is there a good way to add an orthogonality constraint to a weight?
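For reference, this is roughly how weight_norm itself is used (the layer size here is arbitrary); it replaces the weight parameter with weight_g and weight_v and recomputes weight in a forward pre-hook, which is the pattern I tried to copy:

import torch

m = torch.nn.utils.weight_norm(torch.nn.Linear(6, 2), name='weight', dim=0)
print(sorted(n for n, _ in m.named_parameters()))  # ['bias', 'weight_g', 'weight_v']

My own version is applied the same way: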
weight_ortho(torch.nn.Linear(6, 2, bias=True), name='weight', dim=0)
import torch
from torch.nn import Parameter


class WeightOrtho(object):
    def __init__(self, name, dim):
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    def compute_weight(self, module):
        # QR on the transpose gives orthonormal columns; transposing back
        # yields a weight with orthonormal rows
        w = getattr(module, self.name)
        return Parameter(torch.transpose(torch.qr(torch.transpose(w, 0, 1))[0], 0, 1))

    @staticmethod
    def apply(module, name, dim):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightOrtho) and hook.name == name:
                raise RuntimeError("Cannot register two weight_ortho hooks on "
                                   "the same parameter {}".format(name))
        if dim is None:
            dim = -1
        fn = WeightOrtho(name, dim)
        weight = getattr(module, name)
        # remove w from the parameter list, then register a fresh copy
        del module._parameters[name]
        module.register_parameter(name, Parameter(weight.data))
        # assign the orthogonalized weight to the module
        setattr(module, name, fn.compute_weight(module))
        # recompute the weight before every forward()
        module.register_forward_pre_hook(fn)
        return fn

    def remove(self, module):
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name]
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        setattr(module, self.name, self.compute_weight(module))


def weight_ortho(module, name='weight', dim=0):
    WeightOrtho.apply(module, name, dim)
    return module
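A minimal sketch of how I observe the problem (the layer size, data, and learning rate are just made up for illustration):

import torch

layer = weight_ortho(torch.nn.Linear(6, 2, bias=True), name='weight', dim=0)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

x = torch.randn(16, 6)
y = torch.randn(16, 2)
for _ in range(5):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(layer(x), y)
    loss.backward()
    opt.step()
    # the rows stay orthonormal, but the weight itself never trains;
    # only the bias moves with the optimizer steps
    print((layer.weight @ layer.weight.t()).detach())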