I am attempting to create a GRU that uses orthogonal parameterization for the weight matrices stored in `weight_hh_l0`. From the GRU documentation, I know that `weight_hh_l0` contains three weight matrices `W_hr`, `W_hz`, and `W_hn` concatenated together. The problem that I’m facing is that because the GRU module stores the matrices in a concatenated format, applying the orthogonal parameterization via `torch.nn.utils.parametrizations.orthogonal` yields a matrix that is entirely 0 for the bottom two-thirds of the rows.
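
For reference, the concatenated layout can be inspected directly. This is a small sketch that splits `weight_hh_l0` into its three blocks; the names `hidden_size` and `gru` are chosen here for illustration, and the `(W_hr|W_hz|W_hn)` row order is the one given in the GRU documentation.

```python
import torch
from torch import nn

hidden_size = 2  # illustrative value, not tied to any particular model
gru = nn.GRU(hidden_size, hidden_size, num_layers=1)

# weight_hh_l0 has shape (3 * hidden_size, hidden_size): the three
# hidden-to-hidden matrices are stacked along the row dimension.
print(gru.weight_hh_l0.shape)

# Split the rows into three equal blocks in the documented order.
W_hr, W_hz, W_hn = gru.weight_hh_l0.chunk(3, dim=0)
print(W_hr.shape, W_hz.shape, W_hn.shape)
```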

My question is, **what is the correct way to enforce orthogonality constraints on GRU weight matrices so that each of the submatrices of the GRU weights is orthogonal?** I want `W_hr`, `W_hz`, and `W_hn` to all be orthogonal, instead of `W_hr` orthogonal with the other matrices set to the zero matrix.

This is a minimal example demonstrating what I’m observing.

```
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

model_dim = 2
n_batch = 4
max_time = 20
# Shape is (time, batch, features) because batch_first=False.
input_tensor = torch.randn((max_time, n_batch, model_dim))
net = nn.GRU(model_dim, model_dim, num_layers=1, batch_first=False)
# Parametrize the concatenated hidden-to-hidden weight as a single matrix.
net = orthogonal(net, name="weight_hh_l0")
net(input_tensor)
for name, param in net.named_parameters():
    print(name)
    print(param)
```

The result is that the weight matrix `weight_hh_l0` is only orthogonal in the first third (corresponding to `W_hr`), but the latter two-thirds (corresponding to `W_hz` and `W_hn`) are all zeros. **This is not what I want, because it does not contain three orthogonal sub-matrices**. In other words, it seems that `orthogonal` is naïve about `weight_hh_l0` and treats it as a single weight matrix, instead of three independent matrices concatenated together, each of which has an orthogonality constraint.

```
parametrizations.weight_hh_l0.original
Parameter containing:
tensor([[-1.,  0.],
        [ 0., -1.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.]], requires_grad=True)
```
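
Note that the tensor printed above is `parametrizations.weight_hh_l0.original`, which is the unconstrained parameter stored by the parametrization and not necessarily the weight the module exposes. A minimal sketch for inspecting the effective weight instead, assuming that reading `net.weight_hh_l0` returns the parametrized tensor (as `torch.nn.utils.parametrize` describes for parametrized modules). The helper `is_orthogonal` is a hypothetical name introduced here for the check.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal


def is_orthogonal(mat: torch.Tensor, atol: float = 1e-5) -> bool:
    """Hypothetical helper: True if the columns of `mat` are orthonormal."""
    eye = torch.eye(mat.size(-1), dtype=mat.dtype)
    return torch.allclose(mat.T @ mat, eye, atol=atol)


model_dim = 2
net = orthogonal(nn.GRU(model_dim, model_dim, num_layers=1), name="weight_hh_l0")

# Unconstrained tensor stored by the parametrization
# (this is what named_parameters() lists).
original = net.parametrizations.weight_hh_l0.original
# Tensor obtained after the orthogonal map has been applied.
effective = net.weight_hh_l0.detach()

print("original:\n", original)
print("effective:\n", effective)
print("full matrix has orthonormal columns:", is_orthogonal(effective))

# Check each square block separately, in the documented (W_hr|W_hz|W_hn) order.
for gate, block in zip(("W_hr", "W_hz", "W_hn"), effective.chunk(3, dim=0)):
    print(gate, "orthogonal:", is_orthogonal(block))
```

If the full matrix passes the check while the individual blocks do not, that would be consistent with `orthogonal` treating `weight_hh_l0` as one tall matrix rather than three square ones.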

Is there a way that we can enforce orthogonality of all three matrices `W_hr`, `W_hz`, and `W_hn` using `torch.nn.GRU` and `torch.nn.utils.parametrizations.orthogonal`? How so?
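
For comparison, here is a sketch of one possible fallback that does not use `torch.nn.GRU` at all: a hand-written GRU cell in which each hidden-to-hidden matrix is its own square `nn.Linear`, so `orthogonal` can be applied to each one separately. The class `OrthogonalGRUCell` and its attribute names are hypothetical, the gate equations follow the ones in the GRU documentation, and the hidden-to-hidden biases are omitted for brevity. It is not a drop-in replacement for `nn.GRU` (no cuDNN kernels, no multi-layer or bidirectional support).

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal


class OrthogonalGRUCell(nn.Module):
    """Hypothetical GRU cell with three separately orthogonal recurrent matrices."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        # Input-to-hidden weights for the r, z, n gates, stacked in one layer.
        self.x2h = nn.Linear(input_size, 3 * hidden_size)
        # One square hidden-to-hidden layer per gate, each constrained on its own.
        self.h2r = orthogonal(nn.Linear(hidden_size, hidden_size, bias=False))
        self.h2z = orthogonal(nn.Linear(hidden_size, hidden_size, bias=False))
        self.h2n = orthogonal(nn.Linear(hidden_size, hidden_size, bias=False))

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        x_r, x_z, x_n = self.x2h(x).chunk(3, dim=-1)
        r = torch.sigmoid(x_r + self.h2r(h))    # reset gate
        z = torch.sigmoid(x_z + self.h2z(h))    # update gate
        n = torch.tanh(x_n + r * self.h2n(h))   # candidate state
        return (1 - z) * n + z * h


model_dim, n_batch, max_time = 2, 4, 20
cell = OrthogonalGRUCell(model_dim, model_dim)
inputs = torch.randn(max_time, n_batch, model_dim)

# Unroll over the time dimension manually.
h = torch.zeros(n_batch, model_dim)
for x_t in inputs:
    h = cell(x_t, h)

# Each recurrent matrix can be checked against W^T W = I.
for layer_name in ("h2r", "h2z", "h2n"):
    W = getattr(cell, layer_name).weight.detach()
    print(layer_name, torch.allclose(W.T @ W, torch.eye(model_dim), atol=1e-5))
```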