I am attempting to create a GRU that uses orthogonal parameterization for the weight matrices stored in weight_hh_l0. From the GRU documentation, I know that weight_hh_l0 contains three weight matrices, W_hr, W_hz, and W_hn, concatenated together. The problem I’m facing is that, because the GRU module stores the matrices in a concatenated format, applying the orthogonal parameterization via torch.nn.utils.parametrizations.orthogonal yields a matrix that is entirely 0 for the bottom two-thirds of the rows.
My question is: what is the correct way to enforce orthogonality constraints on GRU weight matrices so that each submatrix of the GRU weights is orthogonal? I want W_hr, W_hz, and W_hn to all be orthogonal, instead of W_hr being orthogonal and the other two matrices being set to the zero matrix.
This is a minimal example demonstrating what I’m observing.
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

model_dim = 2
n_batch = 4
max_time = 20

# Input of shape (time, batch, features), since batch_first=False
input_tensor = torch.randn((max_time, n_batch, model_dim))

net = nn.GRU(model_dim, model_dim, num_layers=1, batch_first=False)
# Apply the orthogonal parametrization to the stacked hidden-to-hidden weights
net = orthogonal(net, name="weight_hh_l0")
net(input_tensor)

# Print every registered parameter by name
for name, param in net.named_parameters():
    print(name)
    print(param)
In the printed output, the tensor for weight_hh_l0 is only orthogonal in its first third (corresponding to W_hr), while the latter two-thirds (corresponding to W_hz and W_hn) are all zeros. This is not what I want, because it does not contain three orthogonal sub-matrices. In other words, it seems that orthogonal is naïve about weight_hh_l0 and treats it as a single weight matrix, instead of three independent matrices concatenated together, each of which has an orthogonality constraint.
parametrizations.weight_hh_l0.original
Parameter containing:
tensor([[-1.,  0.],
        [ 0., -1.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.]], requires_grad=True)
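To check what the module actually uses, it may help to compare the tensor printed above with the weight attribute itself. The sketch below assumes that named_parameters() reports the unconstrained tensor stored at parametrizations.weight_hh_l0.original, while net.weight_hh_l0 returns the constrained weight; it then tests the full matrix and each of the three blocks. The variable names raw, W, gram, and block_is_orth are illustrative only.

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)
model_dim = 2

net = nn.GRU(model_dim, model_dim, num_layers=1, batch_first=False)
net = orthogonal(net, name="weight_hh_l0")

# Unconstrained tensor that named_parameters() lists (assumption: this is
# the tensor shown in the printed output, not the weight used by the GRU)
raw = net.parametrizations.weight_hh_l0.original

# Constrained weight, recomputed from `raw` by the parametrization
W = net.weight_hh_l0
print("raw shape:", tuple(raw.shape), "constrained shape:", tuple(W.shape))

# The stacked (3 * hidden, hidden) matrix as a whole: test W^T W against I
eye = torch.eye(model_dim)
gram = W.T @ W
print("columns orthonormal:", torch.allclose(gram, eye, atol=1e-4))

# Each (hidden, hidden) block on its own, in the order W_hr, W_hz, W_hn
block_is_orth = []
for block_name, block in zip(("W_hr", "W_hz", "W_hn"), W.chunk(3, dim=0)):
    ok = torch.allclose(block @ block.T, eye, atol=1e-4)
    block_is_orth.append(ok)
    print(block_name, "orthogonal:", ok)
```

If the assumption holds, the full matrix has orthonormal columns, but three stacked square blocks cannot all be orthogonal at the same time, since their Gram matrices would then sum to 3I rather than I.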
Is there a way to enforce orthogonality of all three matrices W_hr, W_hz, and W_hn using torch.nn.GRU and torch.nn.utils.parametrizations.orthogonal? If so, how?
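For reference, one possible direction is a custom parametrization registered with torch.nn.utils.parametrize.register_parametrization, rather than orthogonal itself. The sketch below is an assumption, not a confirmed answer: BlockOrthogonal is a hypothetical helper that splits the stacked weight into three square blocks and maps each one through the matrix exponential of a skew-symmetric matrix, which only reaches orthogonal matrices with determinant +1. It checks the weight attribute only; whether the GRU forward pass picks up the re-parametrized tensor should be verified on the installed PyTorch version.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class BlockOrthogonal(nn.Module):
    """Hypothetical helper: makes each square block of a stacked weight orthogonal."""

    def __init__(self, n_blocks: int = 3):
        super().__init__()
        self.n_blocks = n_blocks

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        out = []
        # Split the (n_blocks * hidden, hidden) tensor into square blocks
        for block in X.chunk(self.n_blocks, dim=0):
            # Build a skew-symmetric matrix from the strict upper triangle
            upper = block.triu(1)
            skew = upper - upper.mT
            # exp of a skew-symmetric matrix is orthogonal (determinant +1)
            out.append(torch.matrix_exp(skew))
        # Re-stack so the result has the same shape as the original weight
        return torch.cat(out, dim=0)


torch.manual_seed(0)
hidden = 4
net = nn.GRU(hidden, hidden, num_layers=1, batch_first=False)
parametrize.register_parametrization(net, "weight_hh_l0", BlockOrthogonal(n_blocks=3))

W = net.weight_hh_l0
eye = torch.eye(hidden)
blocks_ok = [
    torch.allclose(block @ block.T, eye, atol=1e-4)
    for block in W.chunk(3, dim=0)
]
print("W_hr, W_hz, W_hn orthogonal:", blocks_ok)
```

The design choice here is to keep a single unconstrained tensor of the same shape as weight_hh_l0 and constrain it block by block, so that the GRU's concatenated storage layout is left unchanged.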