How to use orthogonal parameterization with GRU?

I am attempting to create a GRU that uses orthogonal parameterization for the weight matrices stored in weight_hh_l0. From the GRU documentation, I know that weight_hh_l0 contains three weight matrices W_hr, W_hz, and W_hn concatenated together. The problem that I’m facing is that because the GRU module stores the matrices in a concatenated format, applying the orthogonal parameterization via torch.nn.utils.parametrizations.orthogonal yields a matrix that is entirely 0 for the bottom two-thirds of the rows.

My question is: what is the correct way to enforce orthogonality constraints on GRU weight matrices so that each of the sub-matrices of the GRU weights is orthogonal? I want W_hr, W_hz, and W_hn to all be orthogonal, instead of W_hr being orthogonal and the other two matrices being set to the zero matrix.

This is a minimal example demonstrating what I’m observing.

import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal
model_dim = 2
n_batch = 4
max_time = 20

input_tensor = torch.randn((max_time, n_batch, model_dim))

net = nn.GRU(model_dim, model_dim, num_layers=1, batch_first=False)
net = orthogonal(net, name="weight_hh_l0")
net.forward(input_tensor)

for name, param in net.named_parameters():
    print(name)
    print(param)

The result is that the weight matrix weight_hh_l0 is only orthogonal in its first third (corresponding to W_hr), while the latter two-thirds (corresponding to W_hz and W_hn) are all zeros. This is not what I want, because it does not contain three orthogonal sub-matrices. In other words, it seems that orthogonal is naïve about weight_hh_l0 and treats it as a single weight matrix, rather than as three independent matrices concatenated together, each of which should carry its own orthogonality constraint.

parametrizations.weight_hh_l0.original
Parameter containing:
tensor([[-1.,  0.],
        [ 0., -1.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.]], requires_grad=True)

Is there a way to enforce orthogonality of all three matrices W_hr, W_hz, and W_hn using torch.nn.GRU and torch.nn.utils.parametrizations.orthogonal? If so, how?

Hi user0!

Probably the simplest way would be to implement your own parametrization
that splits weight_hh_l0 into the three actual weight matrices, orthogonalizes
them individually, and cats the orthogonal matrices back together.

A sound way to orthogonalize a matrix is to use its singular-value decomposition:

>>> import torch
>>> torch.__version__
'2.0.0'
>>> thin = torch.randn(5, 3)
>>> ortho_thin = torch.linalg.svd(thin, full_matrices=False)[0]
>>> ortho_thin.shape
torch.Size([5, 3])
>>> ortho_thin.T @ ortho_thin
tensor([[ 1.0000e+00, -1.0617e-07, -7.4506e-08],
        [-1.0617e-07,  1.0000e+00,  5.4017e-08],
        [-7.4506e-08,  5.4017e-08,  1.0000e+00]])

(This might not be as good / fast an algorithm as the one under the hood of
torch.nn.utils.parametrizations.orthogonal(), but orthogonal()
doesn’t directly fit the use case you describe.)
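
For concreteness, a minimal sketch of such a split-and-orthogonalize parametrization (the class name and details here are made up for illustration, not a tested recipe) could look something like this:

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class SplitOrthogonal(nn.Module):
    # sketch: split the stacked GRU weight into three row-blocks,
    # orthogonalize each block via SVD, and cat them back together
    def forward(self, X):
        blocks = torch.split(X, X.shape[0] // 3)
        return torch.cat([torch.linalg.svd(b, full_matrices=False)[0] for b in blocks])

net = nn.GRU(2, 2, num_layers=1)
parametrize.register_parametrization(net, "weight_hh_l0", SplitOrthogonal())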

As for doing this directly with orthogonal(): not that I am aware of.

Best.

K. Frank

Hi, K. Frank!

Thanks for the note. I understand what you’re saying about creating a custom parameterization to work around the way that the GRU weight matrices are concatenated. I was hoping that there was a simpler way, but I suppose there isn’t.

But I don’t understand why you say that I would have to use SVD in place of orthogonal. Is this because orthogonal only works on an nn.Module rather than on a torch.Tensor? Or something else?

Moreover, doesn’t SVD suffer from sign ambiguity? If we’re computing an SVD factorization at each iteration, then we would need to somehow persist sign information so that we’re not using U at one iteration and -U at the next iteration.

Thanks

You’d need to implement that on your own, as all the weights are packed together in the GRU implementation. You can take inspiration from the implementations in geotorch (https://github.com/lezcano/geotorch), a constrained optimization toolkit for PyTorch.

The way I’d do it is to go from 3 different matrices to a larger one by putting an orthogonality constraint on the first one and no constraints on the last two, and then concatenating the three to form the larger tensor. You can even keep the remaining two-thirds merged into one tensor, which may be slightly more efficient.

Hi Lezcano,

Thanks for the reply. My goal is to have orthogonality constraints on all 3 matrices. Why do you suggest only constraining the first one? That doesn’t seem to address my question.

I misunderstood what you wanted to do. Then do exactly the same thing but with three matrices. You may need to have a look at how P.orthogonal is implemented and take inspiration from it to implement your own version. It shouldn’t be too difficult though.

Hi user0!

Three things are conspiring to make this more involved than it might have been.

First, orthogonal() doesn’t know anything about GRU (nor is there really any
way for it to). When you call orthogonal(net, name="weight_hh_l0"),
you’re simply telling orthogonal() to act on the "weight_hh_l0" parameter
of your net. orthogonal() doesn’t know that weight_hh_l0 should be
interpreted as consisting of three separate matrices that should be individually
orthogonalized.

Second, pytorch does not provide a functional version of GRU. Consider
Linear, which does have a functional version: You could (but you don’t need
to do it this way) store weight as a free-standing trainable parameter, apply
orthogonal() to your model and weight, and in your model’s forward function,
pass the orthogonalized weight to torch.nn.functional.linear(). You
can’t use this scheme with GRU, as it lacks a functional version.
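
To make the Linear scheme concrete, a rough sketch (module and names invented for illustration) would be:

import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # free-standing trainable weight, not wrapped in nn.Linear
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        # the parametrization hands us the orthogonalized version of self.weight
        return nn.functional.linear(x, self.weight)

model = orthogonal(MyLinear(4, 4), name="weight")
out = model(torch.randn(3, 4))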

Third (and probably poor design in my opinion), orthogonal() doesn’t expose
its orthogonalization algorithm to the outside world. You can see from the code
for orthogonal() that its orthogonalization algorithm is intertwined with its
implementation as a parametrization. This makes it impractical to reuse
orthogonal()'s internal algorithm in a custom parametrization you write
that understands the structure of GRU.

As for SVD’s sign ambiguity: I don’t know enough about pytorch’s svd() and your
use case to know whether it would be a problem in practice, but it’s certainly a
theoretical possibility.
Provided that the relevant rows of weight_hh_l0 don’t become (nearly)
collinear (which would be a problem for any orthogonalization algorithm), you
should be able to resolve any such sign ambiguity by flipping (if necessary) the
rows of the orthogonalized matrices so that they have positive inner product
with the corresponding rows of the original matrices.
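
In code, that fix-up could look something like this (a sketch; ortho and orig stand for the orthogonalized and original matrices, with rows in correspondence):

import torch

def fix_signs(ortho, orig):
    # row-wise inner products between the orthogonalized and original matrices
    dots = (ortho * orig).sum(dim=1, keepdim=True)
    # flipping rows of an orthogonal matrix keeps it orthogonal
    signs = torch.where(dots >= 0, torch.ones_like(dots), -torch.ones_like(dots))
    return ortho * signs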

Best.

K. Frank

I’ve implemented a simple series of functions to create parameterizations of GRU weight matrices that are constrained to be orthogonal. Obviously, there is plenty of room for improvement and abstraction to alternative sequential networks (RNN, LSTM), but this is a dirt-simple starting point for that effort.

The overall strategy is to unstack the concatenated weight matrix into its three blocks, apply a Cayley transform to each block, and then concatenate them back together.

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize


def unpack_gru_weights(X):
    # weight_hh_l{k} stacks three hidden_size x hidden_size matrices along dim 0,
    # so split it into equal thirds by rows
    nrow = X.shape[0]
    return torch.split(X, nrow // 3)


def skew_symmetric(X):
    nrow, ncol = X.shape
    assert nrow == ncol  # this function assumes a square matrix
    A = torch.triu(X, diagonal=1)
    # A is strictly upper-triangular (A[i, i] = 0 for all i), so subtracting its
    # transpose yields a skew-symmetric matrix: A.t() == -A
    return A - A.transpose(0, 1)


def cayley_map(X):
    # Cayley transform: (I + A)^-1 (I - A) is orthogonal when A is skew-symmetric
    A = skew_symmetric(X)
    eye = torch.eye(A.shape[0])
    return torch.linalg.solve(A=eye + A, B=eye - A)


class GruOrthogonal(nn.Module):
    def forward(self, X):
        # The GRU weight matrices are weight_hh_l{k} = [W_hr|W_hz|W_hn]
        #                         and weight_ih_l{k} = [W_ir|W_iz|W_in]
        # We have no need to distinguish the 2 cases, so we name them W_xr, W_xz, W_xn
        w_xr, w_xz, w_xn = unpack_gru_weights(X)
        w_xr = cayley_map(w_xr)
        w_xz = cayley_map(w_xz)
        w_xn = cayley_map(w_xn)
        Xq = torch.cat([w_xr, w_xz, w_xn])
        assert X.shape == Xq.shape
        return Xq


if __name__ == "__main__":
    model_dim = 2
    n_batch = 4
    max_time = 20

    input_tensor = torch.randn((max_time, n_batch, model_dim))

    net = nn.GRU(model_dim, model_dim, num_layers=1, batch_first=False)
    parametrize.register_parametrization(net, "weight_hh_l0", GruOrthogonal())

    net.forward(input_tensor)

    # all_weights[0][1] is the parametrized weight_hh_l0 (layer 0, hidden-to-hidden)
    Q = net.all_weights[0][1]
    print(Q @ Q.t())

The print gives us a quick visual confirmation that this works: the three 2 x 2 sub-matrices along the main diagonal are 2 x 2 identity matrices, as expected, which shows that each block of weight_hh_l0 is orthogonal. All other sub-matrices can be ignored, because they are products of sub-matrices that are irrelevant for our goal here. (Numerical imprecision might mean that some of the elements are not precisely 0 or 1.)
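
For a programmatic check of each block (reusing Q and model_dim from the script above), something like this should print True three times:

for block in torch.split(Q, model_dim):
    print(torch.allclose(block @ block.t(), torch.eye(model_dim), atol=1e-5))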

I elected not to use SVD because (1) I didn’t want to worry about correcting sign indeterminism and (2) I suspect SVD requires the same or greater computational cost, compared to solving the linear system.


That’s exactly what I had in mind. I really like the very clean implementation! The only stylistic point I would make is: you may want to have a look at .mT, which does the same thing as your transpose and .t() but in a more concise way (the m is for “matrix transpose”).

Ah, I see – since I posted this code, I traded .transpose for .T. I’ve read the documentation for .mT and .T and I don’t think .mT is a good choice. The documentation states

The use of Tensor.T() on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider mT to transpose batches of matrices or x.permute(*torch.arange(x.ndim - 1, -1, -1)) to reverse the dimensions of a tensor.

In this usage, we only have 2-tensors, so .T is doing exactly what we want. In a scenario where the user supplies a 3-tensor or higher-order tensor, that usage is unanticipated by these methods, so raising an error is the appropriate outcome.

Indeed, overly-general code is how I got here in the first place: torch.nn.utils.parametrizations.orthogonal didn’t object when I applied it to the weight tensors of a GRU, so I had to do some very annoying tracing to discover that the out-of-the-box orthogonal parameterization is too generalized & too naïve to be used with GRUs.

Also, I’ve realized that unpack_gru_weights is redundant. Instead, I should just do w_xr, w_xz, w_xn = torch.split(W, W.shape[0] // 3). I would edit that post, but for some reason there’s no edit button.

The point is that .mT is almost always what you want on higher-order tensors, and .T is almost never what you want there. But of course, to each their own 🙂
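
A quick illustration of the difference:

>>> import torch
>>> x2 = torch.randn(2, 3)
>>> torch.equal(x2.T, x2.mT)  # identical for a 2-D tensor
True
>>> x3 = torch.randn(4, 2, 3)
>>> x3.mT.shape  # transposes only the last two dims
torch.Size([4, 3, 2])
>>> x3.T.shape  # reverses all dims (deprecated for ndim != 2)
torch.Size([3, 2, 4])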

Note that parametrizations.orthogonal is not overly general. It just has no way of knowing that what you want is not to make the whole matrix orthogonal, but rather to split it in three and make each piece orthogonal.

I think if you re-read my post, you’ll see that is exactly what I’m saying. The weights in nn.GRU are 2-tensors, so splitting them also produces 2-tensors, which is why using .T is preferable to .mT here.

Well, it’s definitely over-general, because the mathematical definition of an orthogonal matrix requires it to be square. Users can apply torch.nn.utils.parametrizations.orthogonal to non-square matrices; indeed, this is the origin of the unexpected behavior outlined in my first post.

Your reply is consistent with what I mean by naïve. If there’s no intention to make the out-of-the-box function torch.nn.utils.parametrizations.orthogonal compatible with the out-of-the-box classes nn.GRU and nn.LSTM, then there should be a note to that effect somewhere in the docs. Or it should raise an error when it is called on a non-square matrix, which would be consistent with the mathematical definition of orthogonal.

Even better, an alternative way to make parametrizations.orthogonal compatible with native pytorch functionality would be to simply not store the weight matrices in this concatenated format! Then it would work as advertised, without users being compelled to write their own parameterization code.

If you think that any of the implementations in core can be improved, we always accept PRs. That said, I don’t think that what you propose can be achieved without incurring a perf hit in the most common use case for LSTM and GRU, that is, when they do not have any parametrisation registered.

I created an issue on the PyTorch GitHub last month that outlines several alternative paths to make the orthogonal parameterization clearer in the case of non-square matrices: “parameterizations.orthogonal does not work as intended with nn.GRU or nn.LSTM” (pytorch/pytorch issue #102740).