How to use orthogonal parameterization with GRU?

I’ve implemented a few simple functions that parameterize the GRU weight matrices so that each gate’s block is constrained to be orthogonal. There is plenty of room for improvement, and for generalizing to other recurrent networks (RNN, LSTM), but this is a dirt-simple starting point for that effort.

The overall strategy is to split each stacked weight matrix into its per-gate blocks, apply a Cayley transform to each block, and then concatenate the results.
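For reference, here is why the Cayley transform of a square block produces an orthogonal matrix. It uses only the construction in the code: keep the strictly upper triangle $T$ of the block $X$ (what `torch.triu(X, diagonal=1)` returns), form the skew-symmetric matrix $A$, then solve for $Q$.

$$
T = \operatorname{triu}_1(X), \qquad A = T - T^\top \;(\text{so } A^\top = -A), \qquad Q = (I + A)^{-1}(I - A)
$$

Because $A$ is real and skew-symmetric, its eigenvalues are purely imaginary, so $I + A$ is always invertible. Using $(I + A)^\top = I - A$ and the fact that $I + A$, $I - A$ and $I - A^2$ are all polynomials in $A$ (and therefore commute):

$$
\begin{aligned}
Q^\top Q &= (I - A)^\top (I + A)^{-\top} (I + A)^{-1} (I - A) \\
&= (I + A)\,(I - A)^{-1} (I + A)^{-1}\,(I - A) \\
&= (I + A)\,(I - A^2)^{-1}\,(I - A) \\
&= (I - A^2)^{-1}\,(I - A^2) = I
\end{aligned}
$$

A side effect worth knowing: $\det Q = \det(I - A)/\det(I + A) = 1$ (since $I - A = (I + A)^\top$), and $-1$ is never an eigenvalue of $Q$, so this map only reaches orthogonal matrices with determinant $+1$ that have no eigenvalue $-1$.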

```python
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize


def unpack_gru_weights(X):
    # Split the stacked (3 * hidden_size, n_cols) weight into its three gate blocks.
    nrow = X.shape[0]
    return torch.split(X, nrow // 3, dim=0)


def skew_symmetric(X):
    nrow, ncol = X.shape
    assert nrow == ncol  # this function assumes a square matrix
    # Keep only the strictly upper triangle, so A[i, i] = 0 for all i
    # and every entry on or below the main diagonal is 0.
    A = torch.triu(X, diagonal=1)
    # A - A^T fills the lower triangle with the negated upper triangle,
    # so the result S satisfies S.t() == -S.
    return A - A.transpose(0, 1)


def cayley_map(X):
    A = skew_symmetric(X)
    # Build the identity with A's dtype and device (e.g. float64 or CUDA).
    eye = torch.eye(A.shape[0], dtype=A.dtype, device=A.device)
    # Q = (I + A)^{-1} (I - A), computed by solving (I + A) Q = I - A.
    return torch.linalg.solve(eye + A, eye - A)


class GruOrthogonal(nn.Module):
    def forward(self, X):
        # The GRU weight matrices are weight_hh_l{k} = [W_hr|W_hz|W_hn]
        #                         and weight_ih_l{k} = [W_ir|W_iz|W_in]
        # We have no need to distinguish the 2 cases, so we name them W_xr, W_xz, W_xn.
        # Each block must be square; for weight_ih_l{k} that requires input_size == hidden_size.
        w_xr, w_xz, w_xn = unpack_gru_weights(X)
        w_xr = cayley_map(w_xr)
        w_xz = cayley_map(w_xz)
        w_xn = cayley_map(w_xn)
        Xq = torch.cat([w_xr, w_xz, w_xn])
        assert X.shape == Xq.shape
        return Xq


if __name__ == "__main__":
    model_dim = 2
    n_batch = 4
    max_time = 20

    input_tensor = torch.randn((max_time, n_batch, model_dim))

    net = nn.GRU(model_dim, model_dim, num_layers=1, batch_first=False)
    parametrize.register_parametrization(net, "weight_hh_l0", GruOrthogonal())

    output, h_n = net(input_tensor)

    # Accessing the attribute returns the parametrized (block-orthogonal) weight.
    Q = net.weight_hh_l0
    print(Q @ Q.t())
```

The print gives us a quick visual confirmation that this works: the 2 x 2 blocks along the main diagonal of `Q @ Q.t()` are 2 x 2 identity matrices, as expected, because each of them is one gate’s block multiplied by its own transpose (e.g. W_hr W_hr^T). The off-diagonal blocks mix two different gates (e.g. W_hr W_hz^T), which we never constrained, so they can be ignored. (Because of floating-point error, some entries may not be exactly 0 or 1.)
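The visual check can be turned into an assertion. The sketch below is a variant written for illustration, not the original code: it folds the helpers into a hypothetical `cayley` function and a hypothetical `BlockwiseCayley(n_gates)` parametrization, so the same idea covers `nn.LSTM` (4 stacked gate blocks) as well as `nn.GRU` (3). The checker `diagonal_blocks_are_identity` and its tolerance `atol=1e-5` are also assumptions chosen for this example.

```python
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

# Hypothetical helper names (cayley, BlockwiseCayley, diagonal_blocks_are_identity)
# chosen for this sketch.


def cayley(X):
    # Skew-symmetric matrix from the strict upper triangle: A.T == -A.
    U = torch.triu(X, diagonal=1)
    A = U - U.T
    eye = torch.eye(A.shape[0], dtype=A.dtype, device=A.device)
    # Q = (I + A)^{-1} (I - A) is orthogonal for any skew-symmetric A.
    return torch.linalg.solve(eye + A, eye - A)


class BlockwiseCayley(nn.Module):
    """Apply the Cayley map separately to each of n_gates blocks stacked along dim 0."""

    def __init__(self, n_gates):
        super().__init__()
        self.n_gates = n_gates

    def forward(self, X):
        blocks = torch.chunk(X, self.n_gates, dim=0)
        return torch.cat([cayley(b) for b in blocks], dim=0)


def diagonal_blocks_are_identity(W, n_gates, atol=1e-5):
    # Each diagonal block of W @ W.T equals W_g @ W_g.T for a single gate g.
    h = W.shape[0] // n_gates
    G = W @ W.T
    eye = torch.eye(h, dtype=W.dtype, device=W.device)
    return all(
        torch.allclose(G[g * h:(g + 1) * h, g * h:(g + 1) * h], eye, atol=atol)
        for g in range(n_gates)
    )


torch.manual_seed(0)
hidden = 8
gru = nn.GRU(hidden, hidden)
lstm = nn.LSTM(hidden, hidden)
parametrize.register_parametrization(gru, "weight_hh_l0", BlockwiseCayley(3))
parametrize.register_parametrization(lstm, "weight_hh_l0", BlockwiseCayley(4))

with torch.no_grad():
    gru_ok = diagonal_blocks_are_identity(gru.weight_hh_l0, 3)
    lstm_ok = diagonal_blocks_are_identity(lstm.weight_hh_l0, 4)
print(gru_ok, lstm_ok)
```

Splitting with `torch.chunk(X, n_gates, dim=0)` assumes the gate blocks are stacked along dimension 0 with shape `(n_gates * hidden_size, hidden_size)`, which is how PyTorch lays out `weight_hh_l0` for GRU and LSTM; a plain `nn.RNN` has a single block, so `n_gates=1` would constrain the whole matrix. For `weight_ih_l0` the blocks are only square when `input_size == hidden_size`.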

I elected not to use SVD because (1) I didn’t want to worry about correcting the sign indeterminacy of the singular vectors and (2) I suspect an SVD costs at least as much as solving the linear system.
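For comparison, here is a sketch of what an SVD route could look like, assuming it means taking the orthogonal polar factor `U @ Vh` of each square block (the name `polar_orthogonal` is hypothetical, and this is not what the post implements). It checks one specific point: negating a matching pair of left and right singular vectors leaves `U @ Vh` unchanged, so for this particular construction the sign indeterminacy cancels instead of needing correction.

```python
import torch


def polar_orthogonal(W):
    # Orthogonal polar factor of a square matrix: W = U S Vh  ->  Q = U @ Vh.
    U, S, Vh = torch.linalg.svd(W)
    return U @ Vh


torch.manual_seed(0)
W = torch.randn(4, 4, dtype=torch.float64)
U, S, Vh = torch.linalg.svd(W)

# Flip the sign of the first left AND right singular vectors together:
# U diag(S) Vh still equals W, and the product U @ Vh is unchanged.
flip = torch.ones(4, dtype=torch.float64)
flip[0] = -1.0
U_flipped = U * flip              # negates column 0 of U
Vh_flipped = flip[:, None] * Vh   # negates row 0 of Vh

Q = polar_orthogonal(W)
Q_flipped = U_flipped @ Vh_flipped
eye = torch.eye(4, dtype=torch.float64)
print(torch.allclose(Q, Q_flipped), torch.allclose(Q.T @ Q, eye))
```

Two other considerations: the `torch.linalg.svd` documentation warns that gradients through `U` and `Vh` become numerically unstable when singular values are close together, and the relative cost of SVD versus `torch.linalg.solve` is best measured on your own sizes and hardware rather than assumed. The two maps also produce different orthogonal matrices from the same raw block, and the polar factor can have determinant -1, whereas the Cayley map cannot.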
