# How to use orthogonal parameterization with GRU?

I am attempting to create a GRU that uses orthogonal parameterization for the weight matrices stored in `weight_hh_l0`. From the GRU documentation, I know that `weight_hh_l0` contains three weight matrices `W_hr`, `W_hz`, and `W_hn` concatenated together. The problem that I’m facing is that because the GRU module stores the matrices in a concatenated format, applying the orthogonal parameterization via `torch.nn.utils.parametrizations.orthogonal` yields a matrix that is entirely 0 for the bottom two-thirds of the rows.

My question is, what is the correct way to enforce orthogonality constraints on GRU weight matrices so that each of the submatrices of the GRU weights is orthogonal? I want `W_hr`, `W_hz`, and `W_hn` to all be orthogonal, instead of `W_hr` orthogonal with the other matrices set to the zero matrix.

This is a minimal example demonstrating what I’m observing.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

model_dim = 2
n_batch = 4
max_time = 20

input_tensor = torch.randn((max_time, n_batch, model_dim))

net = nn.GRU(model_dim, model_dim, num_layers=1, batch_first=False)
net = orthogonal(net, name="weight_hh_l0")
net.forward(input_tensor)

for name, param in net.named_parameters():
    print(name)
    print(param)
```

The result is that the weight matrix `weight_hh_l0` is only orthogonal in the first third (corresponding to `W_hr`), but the latter two-thirds (corresponding to `W_hz`, and `W_hn`) are all zeros. This is not what I want, because it does not contain three orthogonal sub-matrices. In other words, it seems that `orthogonal` is naïve about `weight_hh_l0` and treats it as a single weight matrix, instead of three independent matrices concatenated together, each of which has an orthogonality constraint.
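A quick way to see what `orthogonal()` does enforce here is to read the constrained weight through `net.weight_hh_l0`. (`named_parameters()` lists the unconstrained tensor `parametrizations.weight_hh_l0.original`, which is not the weight the parametrization returns.) For a tall 6 x 2 matrix, `orthogonal()` makes the columns orthonormal, so `Q.T @ Q` is the 2 x 2 identity. That is a constraint on the whole stacked matrix, and it cannot make all three 2 x 2 gate blocks orthogonal at once: the stacked matrix then has squared Frobenius norm 2, while three orthogonal 2 x 2 blocks would need 6. A minimal check with the same sizes as above (the seed is arbitrary):

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)
model_dim = 2

net = nn.GRU(model_dim, model_dim, num_layers=1)
net = orthogonal(net, name="weight_hh_l0")

# Accessing the attribute runs the parametrization and returns the constrained weight
Q = net.weight_hh_l0.detach()  # shape (3 * model_dim, model_dim) = (6, 2)
print(Q.T @ Q)                 # approximately the 2 x 2 identity: orthonormal columns

# The individual gate blocks (W_hr, W_hz, W_hn) are not orthogonal on their own
for block in torch.split(Q, model_dim, dim=0):
    print(block @ block.T)
```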

```
parametrizations.weight_hh_l0.original
Parameter containing:
tensor([[-1.,  0.],
        [ 0., -1.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
```

Is there a way that we can enforce orthogonality of all three matrices `W_hr`, `W_hz`, and `W_hn` using `torch.nn.GRU` and `torch.nn.utils.parametrizations.orthogonal`? How so?

Hi user0!

Probably the simplest way would be to implement your own parametrization
that splits `weight_hh_l0` into the three actual weight matrices, orthogonalizes
them individually, and cats the orthogonal matrices back together.

A sound way to orthogonalize a matrix is to use its singular-value decomposition:

```
>>> import torch
>>> torch.__version__
'2.0.0'
>>> thin = torch.randn (5, 3)
>>> ortho_thin = torch.linalg.svd (thin, full_matrices = False)[0]
>>> ortho_thin.shape
torch.Size([5, 3])
>>> ortho_thin.T @ ortho_thin
tensor([[ 1.0000e+00, -1.0617e-07, -7.4506e-08],
        [-1.0617e-07,  1.0000e+00,  5.4017e-08],
        [-7.4506e-08,  5.4017e-08,  1.0000e+00]])
```

(This might not be as good / fast an algorithm as the one under the hood of
`torch.nn.utils.parametrizations.orthogonal()`, but `orthogonal()`
doesn’t directly fit the use case you describe.)
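Put together, the split / orthogonalize / concatenate parametrization suggested above could look roughly like the following sketch. It departs from the snippet above in one respect (an assumption of this sketch): it uses the polar factor `U @ Vh` of each block, which for a square block is the orthogonal matrix closest to it in Frobenius norm, rather than `U` alone. `BlockwiseSVDOrthogonal` is a hypothetical name.

```python
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize


class BlockwiseSVDOrthogonal(nn.Module):
    """Hypothetical parametrization: orthogonalize each gate block via its SVD."""

    def __init__(self, n_blocks=3):
        super().__init__()
        self.n_blocks = n_blocks

    def forward(self, X):
        rows = X.shape[0] // self.n_blocks
        blocks = []
        for W in torch.split(X, rows, dim=0):
            # W = U diag(S) Vh; the polar factor U @ Vh is the nearest orthogonal matrix
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            blocks.append(U @ Vh)
        return torch.cat(blocks, dim=0)


hidden_size = 4
net = nn.GRU(hidden_size, hidden_size)
parametrize.register_parametrization(net, "weight_hh_l0", BlockwiseSVDOrthogonal())

W = net.weight_hh_l0.detach()
for block in torch.split(W, hidden_size, dim=0):
    print(torch.dist(block @ block.T, torch.eye(hidden_size)))  # close to 0
```

One caveat: the backward pass of `torch.linalg.svd` can be numerically unstable when a block has (nearly) repeated singular values.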

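As for your closing question (whether this can be done with `torch.nn.GRU` and `orthogonal()` directly): not that I am aware of.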

Best.

K. Frank

Hi, K. Frank!

Thanks for the note. I understand what you’re saying about creating a custom `parameterization` to work around the way that the GRU weight matrices are concatenated. I was hoping that there was a simpler way, but I suppose there isn’t.

But I don’t understand when you say that I would have to use SVD in place of `orthogonal`. Is this because `orthogonal` only works on `nn.Module`, vice `torch.Tensor`? Or something else?

Moreover, doesn’t SVD suffer from sign ambiguity? If we’re computing an SVD factorization at each iteration, then we would need to somehow persist sign information so that we’re not using `U` as one iteration and `-U` at the next iteration.
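A small check bearing on the sign question (an aside, not something proposed in the replies): flipping the sign of one singular-vector pair changes `U` and `Vh` but not the factorization, and it also leaves the polar factor `U @ Vh` unchanged. So a parametrization built on `U @ Vh` rather than on `U` alone does not see this ambiguity, at least for nonsingular blocks, where the polar factor is unique.

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, 3)
U, S, Vh = torch.linalg.svd(W)

# Flip the sign of the first singular-vector pair: this is still a valid SVD of W
flip = torch.tensor([-1.0, 1.0, 1.0])
U2 = U * flip             # negates column 0 of U
Vh2 = Vh * flip[:, None]  # negates row 0 of Vh

print(torch.allclose(U2 @ torch.diag(S) @ Vh2, W, atol=1e-5))  # True: same W
print(torch.allclose(U2, U))                                   # False: U is sign-ambiguous
print(torch.allclose(U2 @ Vh2, U @ Vh, atol=1e-6))             # True: U @ Vh is not
```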

Thanks

You’d need to implement that on your own, as all the weights are packed together in the GRU implementation. You can take inspiration from the implementations in geotorch (lezcano/geotorch on GitHub), a constrained optimization toolkit for PyTorch.

The way I’d do it is to go from 3 different matrices into a larger one by putting an orthogonality constraint on the first one and no constraints on the last 2, and then concatenating the three to form the larger tensor. You can even keep the last two-thirds merged into one tensor, which may be slightly more efficient.

Hi Lezcano,

Thanks for the reply. My goal is to have orthogonality constraints on all 3 matrices. Why do you suggest only constraining the first one? That doesn’t seem to address my question.

I misunderstood what you wanted to do. Then do exactly the same thing, but with three matrices. You may need to have a look at how `P.orthogonal` is implemented and take inspiration from it to implement your version. It shouldn’t be too difficult, though.
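One sketch in that spirit, applying to each block the idea behind the `"matrix_exp"` map that `orthogonal()` uses for square matrices: build a skew-symmetric matrix from a triangle of the block and exponentiate it. This is an illustration of the idea, not a copy of PyTorch's code, and `BlockwiseExpOrthogonal` is a hypothetical name. The exponential of a skew-symmetric matrix always has determinant +1, so this sketch on its own only reaches rotations.

```python
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize


class BlockwiseExpOrthogonal(nn.Module):
    """Hypothetical parametrization: Q_i = expm(A_i - A_i^T) for each square gate block."""

    def forward(self, X):
        hidden = X.shape[1]  # weight_hh_l0 has shape (3 * hidden, hidden)
        blocks = []
        for B in torch.split(X, hidden, dim=0):
            A = B.tril(diagonal=-1)
            # A - A.T is skew-symmetric, so its matrix exponential is orthogonal
            blocks.append(torch.matrix_exp(A - A.T))
        return torch.cat(blocks, dim=0)


hidden_size = 3
net = nn.GRU(hidden_size, hidden_size)
parametrize.register_parametrization(net, "weight_hh_l0", BlockwiseExpOrthogonal())
W = net.weight_hh_l0.detach()
```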

Hi user0!

Three things are conspiring to make this more involved than it might have been.

First, `orthogonal()` doesn’t know anything about `GRU` (nor is there really any
way for it to). When you call `orthogonal(net, name=f"weight_hh_l0")`,
you’re simply telling `orthogonal()` to act on the `"weight_hh_l0"` parameter
of your `net`; `orthogonal()` doesn’t know that `weight_hh_l0` should be
interpreted as consisting of three separate matrices that should be individually
orthogonalized.

Second, pytorch does not provide a functional version of `GRU`. Consider
`Linear`, which does have a functional version: You could (but you don’t need
to do it this way) store `weight` as a free-standing trainable parameter, apply
`orthogonal()` to your model and `weight`, and in your model’s forward function,
pass the orthogonalized `weight` to `torch.nn.functional.linear()`. You
can’t use this scheme with `GRU`, as it lacks a functional version.
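For `Linear`, the scheme just described might look like this minimal sketch (`OrthoLinear` is a hypothetical module name). `orthogonal()` parametrizes the free-standing `weight`, and every access to `self.weight` inside `forward` returns the orthogonalized matrix, which is then handed to `torch.nn.functional.linear()`:

```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.parametrizations import orthogonal


class OrthoLinear(nn.Module):
    """Hypothetical module: a free-standing square weight passed to F.linear."""

    def __init__(self, features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(features, features))

    def forward(self, x):
        # self.weight is recomputed by the parametrization on every access
        return F.linear(x, self.weight)


model = orthogonal(OrthoLinear(3), name="weight")
x = torch.randn(5, 3)
y = model(x)
W = model.weight
```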

Third (and probably poor design in my opinion), `orthogonal()` doesn’t expose
its orthogonalization algorithm to the outside world. You can see from the code
for `orthogonal()` that its orthogonalization algorithm is intertwined with its
implementation as a `parametrization`. This makes it impractical to reuse
`orthogonal()`'s internal algorithm in a custom `parametrization` you write
that understands the structure of `GRU`.

As for the sign ambiguity of the SVD: I don’t know enough about pytorch’s `svd()` and your use case to know whether
this would be a problem in practice, but it’s certainly a theoretical possibility.
Provided that the relevant rows of `weight_hh_l0` don’t become (nearly)
collinear (which would be a problem for any orthogonalization algorithm), you
should be able to resolve any such sign ambiguity by flipping (if necessary) the
rows of the orthogonalized matrices so that they have positive inner product
with the corresponding rows of the original matrices.
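The row-flipping fix just described might look like the following sketch. `fix_row_signs` is a hypothetical helper, and QR of `W.T` stands in for whichever orthogonalization is used (its row signs are not canonical either). Flipping rows of an orthogonal matrix amounts to left-multiplying by a diagonal matrix of ±1, so the result stays orthogonal.

```python
import torch


def fix_row_signs(Q, W):
    """Flip rows of orthogonal Q so each has a non-negative inner product with W's row."""
    s = torch.sign((Q * W).sum(dim=1, keepdim=True))  # row-wise inner products
    s[s == 0] = 1.0  # leave rows that are exactly orthogonal to W's rows alone
    return Q * s


torch.manual_seed(0)
W = torch.randn(4, 4)
# Orthonormalize the rows of W via QR of W.T
Q = torch.linalg.qr(W.T).Q.T
Q_fixed = fix_row_signs(Q, W)
print((Q_fixed * W).sum(dim=1))  # all entries non-negative
```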

Best.

K. Frank

I’ve implemented a simple series of functions to create parameterizations of GRU weight matrices that are constrained to be orthogonal. Obviously, there is plenty of room for improvement and abstraction to alternative sequential networks (RNN, LSTM), but this is a dirt-simple starting point for that effort.

The overall strategy is to unstack the weight matrices, use Cayley transformations on the individual parts, and then concatenate them together.

```python
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

def unpack_gru_weights(X):
    # Split the stacked (3 * hidden_size, ncol) weight into its three gate blocks
    nrow, ncol = X.shape
    assert nrow % 3 == 0
    return torch.split(X, nrow // 3, dim=0)

def skew_symmetric(X):
    nrow, ncol = X.shape
    assert nrow == ncol  # this function assumes a square matrix
    A = torch.triu(X, diagonal=1)
    # excluding the diagonal means A[i,i] = 0 for all i;
    # the main diagonal and all diagonals below are 0, so A and -A.T don't overlap.
    return A - A.T  # subtract s.t. A.T = -A

def cayley_map(X):
    # Cayley transform: Q = (I + A)^{-1} (I - A) is orthogonal for skew-symmetric A,
    # and I + A is always invertible in that case
    A = skew_symmetric(X)
    eye = torch.eye(A.shape[0], dtype=A.dtype, device=A.device)
    return torch.linalg.solve(eye + A, eye - A)

class GruOrthogonal(nn.Module):
    def forward(self, X):
        # The GRU weight matrices are weight_hh_l{k} = [W_hr|W_hz|W_hn]
        #                         and weight_ih_l{k} = [W_ir|W_iz|W_in]
        # We have no need to distinguish the 2 cases, so we name them W_xr, W_xz, W_xn
        # (the weight_ih blocks are square only when input_size == hidden_size)
        w_xr, w_xz, w_xn = unpack_gru_weights(X)
        w_xr = cayley_map(w_xr)
        w_xz = cayley_map(w_xz)
        w_xn = cayley_map(w_xn)
        Xq = torch.cat([w_xr, w_xz, w_xn])
        assert X.shape == Xq.shape
        return Xq

if __name__ == "__main__":
    model_dim = 2
    n_batch = 4
    max_time = 20

    input_tensor = torch.randn((max_time, n_batch, model_dim))

    net = nn.GRU(model_dim, model_dim, num_layers=1, batch_first=False)
    parametrize.register_parametrization(net, "weight_hh_l0", GruOrthogonal())

    output, h_n = net(input_tensor)

    # net.all_weights is a nested list, so inspect the parametrized matrix itself
    Q = net.weight_hh_l0
    print(Q @ Q.T)
```

The print gives us a quick visual confirmation that this works: the 2 x 2 sub-matrices along the main diagonal are 2 x 2 identity matrices, as expected. All other sub-matrices can be ignored, because they are products of sub-matrices that are irrelevant for our goals here. (Numerical imprecision might mean that some of the elements are not precisely 0.)
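Rather than eyeballing the printed matrix, the diagonal blocks can be checked programmatically. `gru_blocks_are_orthogonal` is a hypothetical helper; it assumes a `weight_hh_l{k}`-shaped tensor whose three blocks are square (hidden_size x hidden_size), e.g. `net.weight_hh_l0.detach()`.

```python
import torch


def gru_blocks_are_orthogonal(weight, atol=1e-5):
    """Return True if every square gate block of a stacked GRU weight is orthogonal."""
    hidden = weight.shape[1]
    eye = torch.eye(hidden, dtype=weight.dtype, device=weight.device)
    return all(
        torch.allclose(block @ block.T, eye, atol=atol)
        for block in torch.split(weight, hidden, dim=0)
    )
```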

I elected not to use SVD because (1) I didn’t want to worry about correcting sign indeterminism and (2) I suspect SVD requires the same or greater computational cost, compared to solving the linear system.


That’s exactly what I had in mind. I really like the very clean implementation! The only stylistic point I would make is: You may want to have a look at `.mT` which does the same thing as your `transpose` and `t` but in a more concise way (the `m` is for “matrix transpose”).
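A small illustration of the difference: on a 2-D tensor `.mT` and `.T` agree, while on a batch of matrices `.mT` swaps only the last two dimensions.

```python
import torch

M = torch.randn(2, 3)
print(torch.equal(M.mT, M.T))  # True: same thing on a 2-D tensor

batch = torch.randn(5, 2, 3)   # a batch of five 2 x 3 matrices
print(batch.mT.shape)          # torch.Size([5, 3, 2]): only the last two dims swap
```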

Ah, I see – since I posted this code, I traded `.transpose` for `.T`. I’ve read the documentation for `.mT` and `.T` and I don’t think `.mT` is a good choice. The documentation states

> The use of `Tensor.T()` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor.

In this usage, we only have 2-tensors, so `.T` is doing exactly what we want. In a scenario where the user supplies a 3-tensor or higher-order tensor, that usage is unanticipated by these methods, so raising an error is the appropriate outcome.

Indeed, overly-general code is how I got here in the first place: `torch.nn.utils.parametrizations.orthogonal` didn’t object when I applied it to the weight tensors of a GRU, so I had to do some very annoying tracing to discover that the out-of-the-box orthogonal parameterization is too generalized & too naïve to be used with GRUs.

Also, I’ve realized that `unpack_gru_weights` is redundant. Instead, I should just do `w_xr, w_xz, w_xn = torch.split(W, W.shape[0] // 3)`. I would edit that post, but for some reason there’s no edit button.
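A quick check of that split on a stand-in tensor (assumed here: `hidden_size = 2`, so `W` is 6 x 2). The split size is the row count `W.shape[0] // 3`; `torch.chunk` is an equivalent alternative that takes the number of pieces instead of the piece size.

```python
import torch

W = torch.randn(6, 2)  # stand-in for weight_hh_l0 with hidden_size = 2
w_xr, w_xz, w_xn = torch.split(W, W.shape[0] // 3, dim=0)

# torch.chunk splits by number of pieces rather than by piece size
w_xr2, w_xz2, w_xn2 = torch.chunk(W, 3, dim=0)
print(w_xr.shape)  # torch.Size([2, 2])
```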

The point is that `.mT` is almost always what you want on higher-order tensors, and `.T` is almost never what you want on them. But of course, to each their own. Note that `parametrizations.orthogonal` is not overly general. It just has no way to know that what you want is not to make the whole matrix orthogonal, but to split it in three and then make each of them orthogonal.

I think if you re-read my post, you’ll see that is exactly what I’m saying. The weight in `nn.GRU` are 2-tensors, so splitting them also produces 2-tensors, which is why using `T` is preferable to `mT`.

Well, it’s definitely over-general because part of the mathematical definition of an orthogonal matrix is that an orthogonal matrix is square. Users can apply `torch.nn.utils.parametrizations.orthogonal` to non-square matrices; indeed, this is the origination of the unexpected behavior outlined in my first post.

Your reply is consistent with what I mean by naïve. If there’s no intention to make the out-of-the-box function `torch.nn.utils.parametrizations.orthogonal` compatible with the out-of-the-box classes `nn.GRU` and `nn.LSTM`, then there should be a note to that effect somewhere in the docs. Or it should raise an error when it is called on a non-square matrix, which would be consistent with the mathematical definition of orthogonal.
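The error-raising behavior proposed above can be approximated on the user side with a thin wrapper. `square_orthogonal` is a hypothetical name; it only checks the shape before delegating to `orthogonal()`, so it also rejects batched weights.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal


def square_orthogonal(module, name="weight", **kwargs):
    """Hypothetical wrapper: refuse to apply orthogonal() to a non-square matrix."""
    weight = getattr(module, name)
    if weight.ndim != 2 or weight.shape[0] != weight.shape[1]:
        raise ValueError(
            f"{name} has shape {tuple(weight.shape)}; expected a square matrix"
        )
    return orthogonal(module, name=name, **kwargs)


linear = square_orthogonal(nn.Linear(3, 3))  # square weight: accepted

try:
    square_orthogonal(nn.GRU(2, 2), name="weight_hh_l0")
except ValueError as err:
    print(err)  # weight_hh_l0 has shape (6, 2); expected a square matrix
```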

Even better, an alternative way to make `parametrizations.orthogonal` compatible with native pytorch functionality would be to simply not store the weight matrices in this concatenated format! Then it would work as advertised, without users being compelled to write their own parameterization code.
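For comparison, here is a sketch of what per-matrix storage looks like in user code: a hypothetical hand-written cell (`SeparateWeightGRUCell`, not a PyTorch class) that follows the gate equations in the `nn.GRU` documentation and applies `orthogonal()` to each recurrent matrix separately. It is a single time step, and it will likely be slower than the fused `nn.GRU` kernel.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal


class SeparateWeightGRUCell(nn.Module):
    """Hypothetical GRU cell storing W_hr, W_hz, W_hn as separate square matrices."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.x_r, self.x_z, self.x_n = [nn.Linear(input_size, hidden_size) for _ in range(3)]
        self.h_r, self.h_z, self.h_n = [nn.Linear(hidden_size, hidden_size) for _ in range(3)]
        for lin in (self.h_r, self.h_z, self.h_n):
            orthogonal(lin)  # each recurrent matrix is constrained on its own

    def forward(self, x, h):
        r = torch.sigmoid(self.x_r(x) + self.h_r(h))
        z = torch.sigmoid(self.x_z(x) + self.h_z(h))
        n = torch.tanh(self.x_n(x) + r * self.h_n(h))
        return (1 - z) * n + z * h


cell = SeparateWeightGRUCell(input_size=3, hidden_size=4)
h_next = cell(torch.randn(5, 3), torch.zeros(5, 4))
```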

If you think that any of the implementations in core can be improved, we always accept PRs. Now, I don’t think that what you propose can be achieved without incurring a perf hit on the most common use case for LSTM and GRU, that is, when they do not have any parametrisation registered.

I created an issue on the PyTorch GitHub last month. It outlines several alternative paths to make the orthogonal parameterization clearer in the case of non-square matrices: pytorch/pytorch issue #102740, “parameterizations.orthogonal does not work as intended with nn.GRU or nn.LSTM”.