How do I add an orthogonality constraint to a layer's weight?

I’m trying to add an orthogonality constraint to the weight of a Linear layer. This seems similar to what torch.nn.utils.weight_norm does, so I wrote a similar class, WeightOrtho, to orthogonalize the weight: each time it runs, the forward pre-hook assigns a new orthogonal weight to the layer. However, this does not seem to work: the weight is not updated during training. Is there a good way to add an orthogonality constraint to a weight?

# Example: constrain the weight of a 6 -> 2 linear layer. This line relies on
# weight_ortho / WeightOrtho, which are defined below, so it must run after them.
weight_ortho(torch.nn.Linear(in_features=6, out_features=2, bias=True),
             name='weight', dim=0)
class WeightOrtho(object):
    """Forward pre-hook that keeps a module parameter orthogonal.

    Modeled on ``torch.nn.utils.weight_norm``:

    * The trainable Parameter is moved to ``<name>_orig``.
    * Before every ``forward()``, ``<name>`` is recomputed from it as a plain,
      differentiable tensor via a QR decomposition.

    Because ``<name>`` is derived from ``<name>_orig`` inside the autograd
    graph, gradients flow back to ``<name>_orig`` and the optimizer updates it.

    Why the previous version did not train: it wrapped the QR result in a
    *new* ``torch.nn.Parameter`` on every forward. That new leaf replaced the
    registered parameter, so the optimizer kept stepping a stale object and
    no gradient ever reached the original weight.

    Orthogonalization direction: a weight of shape ``(rows, cols)`` gets
    orthonormal rows when ``rows <= cols`` and orthonormal columns otherwise.
    Weights with more than two dimensions are flattened to
    ``(shape[0], -1)`` first and reshaped back afterwards.
    """

    def __init__(self, name, dim):
        if dim is None:
            dim = -1
        self.name = name
        # Stored for API compatibility with weight_norm-style callers; the
        # orthogonalization direction is chosen from the weight's shape.
        self.dim = dim
        # Handle returned by register_forward_pre_hook, used by remove().
        self._handle = None

    def compute_weight(self, module):
        """Return the orthogonalized ``<name>_orig`` as a plain tensor.

        The result keeps the autograd connection to ``<name>_orig``.
        """
        w = getattr(module, self.name + '_orig')
        shape = w.shape
        mat = w.reshape(shape[0], -1)
        # Reduced QR of a tall (or square) m x n matrix returns an m x n Q
        # with orthonormal columns. Transpose wide matrices so Q has the
        # shape we need, and transpose back afterwards.
        wide = mat.shape[0] <= mat.shape[1]
        q, r = torch.linalg.qr(mat.t() if wide else mat)
        # QR is only unique up to column signs. Forcing diag(R) >= 0 makes
        # the map deterministic, and an already-orthogonal matrix maps to
        # itself. The sign is piecewise constant, so it is detached.
        d = torch.diagonal(r).detach()
        sign = torch.where(d < 0, -torch.ones_like(d), torch.ones_like(d))
        q = q * sign
        if wide:
            q = q.t()
        return q.reshape(shape)

    @staticmethod
    def apply(module, name, dim):
        """Install the hook on ``module.<name>`` and return the hook object.

        Raises:
            RuntimeError: if ``module.<name>`` already has a WeightOrtho hook.
        """
        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, WeightOrtho) and hook.name == name:
                raise RuntimeError("Cannot register two weight_ortho hooks on "
                                   "the same parameter {}".format(name))

        if dim is None:
            dim = -1

        fn = WeightOrtho(name, dim)

        weight = getattr(module, name)

        # Move the trainable Parameter to "<name>_orig". Re-registering the
        # *same* Parameter object (not a copy) means an optimizer that
        # already holds it keeps updating it.
        del module._parameters[name]
        module.register_parameter(name + '_orig', weight)
        # "<name>" becomes a plain tensor attribute derived from "<name>_orig".
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        fn._handle = module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        """Bake the current orthogonal weight back into ``module.<name>``.

        Afterwards ``module.<name>`` is an ordinary Parameter again, and the
        hook and ``<name>_orig`` are removed.
        """
        weight = self.compute_weight(module)
        if self._handle is not None:
            # Without this, the hook would keep firing and look up the
            # deleted "<name>_orig".
            self._handle.remove()
            self._handle = None
        delattr(module, self.name)
        del module._parameters[self.name + '_orig']
        module.register_parameter(self.name, torch.nn.Parameter(weight.data))

    def __call__(self, module, inputs):
        # Forward pre-hook: refresh the derived weight from "<name>_orig".
        setattr(module, self.name, self.compute_weight(module))

def weight_ortho(module, name='weight', dim=0):
    """Register a ``WeightOrtho`` forward pre-hook on ``module`` and return it.

    Args:
        module: module whose parameter should be kept orthogonal.
        name: name of the parameter to constrain (default ``'weight'``).
        dim: forwarded unchanged to ``WeightOrtho.apply``.

    Returns:
        The same ``module`` object, so the call can be chained, e.g.
        ``layer = weight_ortho(torch.nn.Linear(6, 2))``.
    """
    WeightOrtho.apply(module, name=name, dim=dim)
    return module

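Have a look at this link. (The URL was not preserved in this copy. Judging by the `self.params.map_beta` hyperparameter mentioned below, it most likely pointed to the `orthogonalize` method of the MUSE project, facebookresearch/MUSE; please verify.)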

It is an iterative orthogonalization procedure: you call it repeatedly (for example, after every optimizer step) until the linear layer it acts on converges to orthogonality. Each call applies the update $W \leftarrow (1+\beta)\,W - \beta\, W W^\top W$. If you are wondering about their implementation, `self.params.map_beta` is just this scalar hyperparameter $\beta$, which defaults to 0.001.

Slightly simplified, their procedure looks like this:

beta = 0.001


def orthogonalize(self, step=None):
    """Apply one step of iterative orthogonalization to ``self.linear_layer``.

    Updates ``self.linear_layer.weight`` in place with

        W <- (1 + b) * W - b * W @ W.T @ W

    Each singular value s of W maps to s + b * s * (1 - s**2). Calling this
    repeatedly with a small step ``b`` therefore pushes the singular values
    towards 1, i.e. towards an orthogonal matrix.

    Args:
        self: any object with a ``linear_layer`` attribute whose ``weight``
            is a 2-D tensor (e.g. a ``torch.nn.Linear``).
        step: step size ``b``. Defaults to the module-level ``beta``, read at
            call time, so reassigning ``beta`` still takes effect.
    """
    b = beta if step is None else step
    W = self.linear_layer.weight
    # torch.no_grad() (instead of mutating ``.data``) keeps the in-place
    # update out of the autograd graph while still bumping the version
    # counter, so autograd can detect stale saved tensors.
    with torch.no_grad():
        W.copy_((1 + b) * W - b * W.mm(W.t().mm(W)))

A bit late to the party, but starting with PyTorch 1.10 this feature is supported natively via `torch.nn.utils.parametrizations.orthogonal`:
https://pytorch.org/docs/stable/generated/torch.nn.utils.parametrizations.orthogonal.html

Leaving this here in case someone else bumps into this post in the future.