I am currently trying to figure out how to correctly initialize GRU/GRUCell weight matrices, and noticed that those matrices are the concatenation of the reset/update/new gate weights, giving a leading dimension of 3 * hidden_size for both the input-to-hidden and hidden-to-hidden matrices.
I took a look at the reset_parameters() method in the GRUCell code, and saw that the variance of the initializer is computed over the hidden size alone, which gives coherent results.
So, when applying an orthogonal and/or Xavier init to those matrices, I was wondering if they should be chunked per gate so that PyTorch computes the fan_in/fan_out correctly?
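For context, here is a small check I ran (my own experiment, with a hypothetical hidden size, not something from the PyTorch source): Xavier's uniform bound is sqrt(6 / (fan_in + fan_out)), so initializing the stacked (3 * hidden, hidden) matrix in one call sees fan_out = 3 * hidden and produces a narrower distribution than initializing each (hidden, hidden) gate chunk separately:

```python
import torch

# Hypothetical sizes, just to illustrate the fan computation.
hidden = 64
w = torch.empty(3 * hidden, hidden)  # stacked reset/update/new gate weights

# Xavier over the full stacked matrix sees fan_in = hidden,
# fan_out = 3 * hidden, so the bound sqrt(6 / (fan_in + fan_out)) shrinks.
torch.nn.init.xavier_uniform_(w)
full_std = w.std().item()

# Xavier per gate chunk sees fan_in = fan_out = hidden.
for gate in w.chunk(3, 0):
    torch.nn.init.xavier_uniform_(gate)
chunk_std = w.std().item()

print(full_std < chunk_std)  # True: the per-gate init is wider
```

So the two approaches are genuinely different, which is what prompted the question.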
Here is a snippet of what I am currently doing; I would like to confirm this is the right way to do it:
```python
@staticmethod
def weights_init(x):
    if isinstance(x, GRU):
        for n, p in x.named_parameters():
            if 'weight_ih' in n:
                for ih in p.chunk(3, 0):
                    torch.nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in n:
                for hh in p.chunk(3, 0):
                    torch.nn.init.orthogonal_(hh)
            elif 'bias_ih' in n:
                torch.nn.init.zeros_(p)
            # elif 'bias_hh' in n:
            #     torch.nn.init.ones_(p)
    elif isinstance(x, GRUCell):
        for hh, ih in zip(x.weight_hh.chunk(3, 0), x.weight_ih.chunk(3, 0)):
            torch.nn.init.orthogonal_(hh)
            torch.nn.init.xavier_uniform_(ih)
        torch.nn.init.zeros_(x.bias_ih)
```
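As a sanity check (not an authoritative answer), here is a condensed, self-contained variant of the function above, with hypothetical layer sizes, applied via Module.apply. Since chunk(3, 0) returns views into the stacked parameter, the in-place inits should write straight through, and each hidden-to-hidden chunk should come out orthogonal:

```python
import torch
from torch.nn import GRU, GRUCell

def weights_init(x):
    # Condensed variant of the snippet above: GRUCell also exposes
    # named_parameters(), so both module types can share one branch.
    if isinstance(x, (GRU, GRUCell)):
        for n, p in x.named_parameters():
            if 'weight_ih' in n:
                for ih in p.chunk(3, 0):
                    torch.nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in n:
                for hh in p.chunk(3, 0):
                    torch.nn.init.orthogonal_(hh)
            elif 'bias_ih' in n:
                torch.nn.init.zeros_(p)

gru = GRU(input_size=32, hidden_size=64)  # hypothetical sizes
gru.apply(weights_init)

# chunk(3, 0) returned views, so the in-place inits wrote into the
# stacked parameter; verify the first hh chunk is orthogonal.
r = gru.weight_hh_l0.detach().chunk(3, 0)[0]
print(torch.allclose(r @ r.T, torch.eye(64), atol=1e-4))  # True
```

This at least confirms the chunked in-place initialization takes effect on the underlying parameters.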