GRU weight matrices initialization

Hi,

I am currently trying to figure out how to correctly initialize GRU/GRUCell weight matrices, and noticed that each of those matrices is the concatenation of the reset/update/new gate weights. This gives a first dimension of 3 * hidden_size for both the input-to-hidden (weight_ih) and the hidden-to-hidden (weight_hh) matrices.
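A quick way to see this layout is to print the shapes and split them per gate. This is a minimal sketch; the sizes are arbitrary and chosen only for illustration:

```python
import torch
from torch import nn

# Arbitrary sizes, for illustration only.
input_size, hidden_size = 10, 20

cell = nn.GRUCell(input_size, hidden_size)
gru = nn.GRU(input_size, hidden_size, num_layers=2)

# Both matrices stack the three gates along dim 0.
print(cell.weight_ih.shape)  # torch.Size([60, 10])
print(cell.weight_hh.shape)  # torch.Size([60, 20])

# nn.GRU suffixes its per-layer parameters with _l{k}.
print([name for name, _ in gru.named_parameters()])
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', 'weight_ih_l1', ...]

# chunk(3, 0) gives one block per gate, in the documented order
# reset (r), update (z), new (n).
w_ir, w_iz, w_in = cell.weight_ih.chunk(3, 0)
print(w_ir.shape)  # torch.Size([20, 10])
```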

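I took a look at the reset_parameters() method in the GRUCell code and noticed that the bound of its uniform initializer is computed from hidden_size (1/sqrt(hidden_size)) rather than from 3 * hidden_size, so the stacking of the gates does not change the scale of the default initialization, which seems coherent.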
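The default behaviour can be checked directly. The sketch below assumes only the default GRUCell constructor and verifies that every parameter lies within ±1/sqrt(hidden_size):

```python
import math

import torch
from torch import nn

input_size, hidden_size = 10, 20
cell = nn.GRUCell(input_size, hidden_size)

# Default bound used by reset_parameters(): it depends on hidden_size only,
# not on the stacked 3 * hidden_size first dimension.
bound = 1.0 / math.sqrt(hidden_size)

for name, p in cell.named_parameters():
    # Weights and biases are all drawn from U(-bound, bound); expect True for each.
    print(name, tuple(p.shape), p.abs().max().item() <= bound)

# The variance of U(-b, b) is b**2 / 3, i.e. 1 / (3 * hidden_size) here.
print(bound ** 2 / 3)
```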

Then, when applying an orthogonal and/or Xavier init to those matrices, I was wondering whether they should be chunked per gate so that PyTorch computes fan_in/fan_out correctly (on the full matrix, Xavier sees fan_out = 3 * hidden_size).
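The difference can be made concrete. The sketch below recomputes the Xavier-uniform bound (gain * sqrt(6 / (fan_in + fan_out)), with fan_in = size(1) and fan_out = size(0) for a 2-D weight) for the whole matrix and for one gate, then compares orthogonal init on the whole weight_hh against orthogonal init per gate. The helper `xavier_bound` is hypothetical, written here only for illustration:

```python
import math

import torch
from torch import nn

input_size, hidden_size = 10, 20
cell = nn.GRUCell(input_size, hidden_size)


def xavier_bound(w: torch.Tensor, gain: float = 1.0) -> float:
    # Hypothetical helper mirroring the xavier_uniform_ bound for a 2-D weight.
    fan_out, fan_in = w.shape
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


# Whole matrix: fan_out = 3 * hidden_size, so the bound is smaller.
print(xavier_bound(cell.weight_ih))                 # sqrt(6 / (10 + 60))
# One gate: fan_out = hidden_size, as if the gate were its own layer.
print(xavier_bound(cell.weight_ih.chunk(3, 0)[0]))  # sqrt(6 / (10 + 20))

eye = torch.eye(hidden_size)

with torch.no_grad():
    # Orthogonal init on the whole (3H, H) matrix: only the stacked columns
    # are orthonormal, so a single gate block is not orthogonal on its own.
    nn.init.orthogonal_(cell.weight_hh)
    whole_block = cell.weight_hh.chunk(3, 0)[0].clone()

    # Orthogonal init gate by gate: chunk() returns views, so each (H, H)
    # block of the parameter is overwritten with its own orthogonal matrix.
    for block in cell.weight_hh.chunk(3, 0):
        nn.init.orthogonal_(block)
    per_gate_block = cell.weight_hh.chunk(3, 0)[0].clone()

print(torch.allclose(whole_block.T @ whole_block, eye, atol=1e-4))        # False
print(torch.allclose(per_gate_block.T @ per_gate_block, eye, atol=1e-4))  # True
```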

Here is a snippet of what I am currently thinking; I would like to confirm that this is the right way to do it:

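    import torch
    from torch.nn import GRU, GRUCell

    def weights_init(x):
        # Can be passed to model.apply(weights_init) or kept as a @staticmethod in a class.
        if isinstance(x, GRU):
            for n, p in x.named_parameters():
                if 'weight_ih' in n:
                    # One Xavier init per gate: fan_out = hidden_size, not 3 * hidden_size
                    for ih in p.chunk(3, 0):
                        torch.nn.init.xavier_uniform_(ih)
                elif 'weight_hh' in n:
                    # One orthogonal (hidden_size x hidden_size) block per gate
                    for hh in p.chunk(3, 0):
                        torch.nn.init.orthogonal_(hh)
                elif 'bias_ih' in n:
                    torch.nn.init.zeros_(p)
                # elif 'bias_hh' in n:
                #     torch.nn.init.ones_(p)

        elif isinstance(x, GRUCell):
            for hh, ih in zip(x.weight_hh.chunk(3, 0), x.weight_ih.chunk(3, 0)):
                torch.nn.init.orthogonal_(hh)
                torch.nn.init.xavier_uniform_(ih)

            if x.bias_ih is not None:  # bias_ih is None when bias=False
                torch.nn.init.zeros_(x.bias_ih)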
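To check that a per-gate initialization actually did what is intended, a small verification helper can be run after it. `check_gru_init` below is a hypothetical helper (not a PyTorch API); it assumes the per-gate orthogonal weight_hh and zero bias_ih scheme from the snippet above and reports whether every gate block of each weight_hh is orthogonal and every bias_ih is zero:

```python
import torch
from torch import nn


def check_gru_init(module: nn.Module, atol: float = 1e-4) -> bool:
    """Hypothetical helper: True if every weight_hh gate block is orthogonal
    and every bias_ih is zero, for an nn.GRU or nn.GRUCell."""
    ok = True
    for name, p in module.named_parameters():
        if "weight_hh" in name:
            eye = torch.eye(p.shape[1], dtype=p.dtype, device=p.device)
            # Split the stacked (3H, H) matrix back into its reset/update/new blocks.
            for block in p.detach().chunk(3, 0):
                ok &= torch.allclose(block.T @ block, eye, atol=atol)
        elif "bias_ih" in name:
            ok &= bool(torch.all(p == 0))
    return ok


gru = nn.GRU(10, 20, num_layers=2)
print(check_gru_init(gru))  # False with the default uniform init

# Per-gate init, same scheme as the question's snippet.
with torch.no_grad():
    for name, p in gru.named_parameters():
        if "weight_hh" in name:
            for block in p.chunk(3, 0):
                nn.init.orthogonal_(block)
        elif "bias_ih" in name:
            nn.init.zeros_(p)

print(check_gru_init(gru))  # True
```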