Hi,
I am currently trying to figure out how to correctly initialize the GRU/GRUCell weight matrices, and I noticed that each of them is the concatenation of the reset/update/new gate matrices, which gives a first dimension of 3 * hidden_size for both the input-to-hidden and the hidden-to-hidden weights.
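To make the shapes concrete, here is a small check. The sizes 10 and 20 are arbitrary values I picked for illustration, nothing more:

```python
import torch
from torch.nn import GRU, GRUCell

input_size, hidden_size = 10, 20  # arbitrary illustrative sizes

cell = GRUCell(input_size, hidden_size)
# The three gates are stacked along dim 0, so the first dimension is 3 * hidden_size.
cell_ih_shape = tuple(cell.weight_ih.shape)
cell_hh_shape = tuple(cell.weight_hh.shape)

gru = GRU(input_size, hidden_size, num_layers=1)
# Single-layer GRU parameters are named weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0.
gru_shapes = {n: tuple(p.shape) for n, p in gru.named_parameters()}

print(cell_ih_shape, cell_hh_shape)
print(gru_shapes)
```

So only the first dimension is tripled; the second one is still input_size or hidden_size.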
I took a look at the reset_parameters() method in the GRUCell code and noticed that the variance of the initializer is computed over the hidden size, so the default initialization gives coherent results.
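As far as I can tell from reading the code, the default draws every parameter uniformly from [-1/sqrt(hidden_size), 1/sqrt(hidden_size)], independent of the stacked 3 * hidden_size dimension. A quick sketch to check that reading (sizes are again arbitrary, and the small tolerance is only there for float32 rounding):

```python
import math
import torch
from torch.nn import GRUCell

input_size, hidden_size = 10, 20  # arbitrary illustrative sizes
cell = GRUCell(input_size, hidden_size)

# Bound I expect from reset_parameters(): it depends on hidden_size only.
bound = 1.0 / math.sqrt(hidden_size)

# Largest absolute value over all weights and biases of the cell.
max_abs = max(p.detach().abs().max().item() for p in cell.parameters())
within_bound = max_abs <= bound + 1e-6

print(f"bound={bound:.4f}, max |param|={max_abs:.4f}, within_bound={within_bound}")
```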
Then, when trying to apply an orthogonal and/or Xavier init to those matrices, I was wondering whether they should be chunked so that PyTorch computes the fan_in/fan_out correctly?
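To show why I think it matters, here is a sketch comparing the Xavier uniform bound, gain * sqrt(6 / (fan_in + fan_out)), for the whole stacked matrix and for a single gate. xavier_bound is a hypothetical helper I wrote for this comparison, not a PyTorch function, and the sizes are arbitrary:

```python
import math
import torch

torch.manual_seed(0)
input_size, hidden_size = 10, 20  # arbitrary illustrative sizes

def xavier_bound(weight, gain=1.0):
    # Hypothetical helper: for a 2-D weight, fan_out = rows and fan_in = columns.
    fan_out, fan_in = weight.shape
    return gain * math.sqrt(6.0 / (fan_in + fan_out))

weight_ih = torch.empty(3 * hidden_size, input_size)

# Whole matrix: fan_out = 3 * hidden_size, so the bound shrinks.
full_bound = xavier_bound(weight_ih)
torch.nn.init.xavier_uniform_(weight_ih)
full_max = weight_ih.abs().max().item()

# Per gate: fan_out = hidden_size, the bound each gate would get on its own.
gate_bound = xavier_bound(weight_ih.chunk(3, 0)[0])
for ih in weight_ih.chunk(3, 0):
    torch.nn.init.xavier_uniform_(ih)  # in-place on a view of weight_ih
gate_max = weight_ih.abs().max().item()

print(f"full: bound={full_bound:.4f}, max={full_max:.4f}")
print(f"gate: bound={gate_bound:.4f}, max={gate_max:.4f}")
```

With these sizes the whole-matrix bound is sqrt(6/70) and the per-gate bound is sqrt(6/30), so initializing the stacked matrix in one call gives each gate noticeably smaller weights than initializing it separately.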
Here is a snippet of what I have in mind; I would like to confirm that this is the right way to do it:
# Module-level imports needed by the snippet
import torch
from torch.nn import GRU, GRUCell

@staticmethod
def weights_init(x):
    if isinstance(x, GRU):
        for n, p in x.named_parameters():
            if 'weight_ih' in n:
                # One Xavier init per gate (reset, update, new)
                for ih in p.chunk(3, 0):
                    torch.nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in n:
                # One orthogonal init per gate
                for hh in p.chunk(3, 0):
                    torch.nn.init.orthogonal_(hh)
            elif 'bias_ih' in n:
                torch.nn.init.zeros_(p)
            # elif 'bias_hh' in n:
            #     torch.nn.init.ones_(p)
    elif isinstance(x, GRUCell):
        # Same per-gate scheme for the single cell
        for hh, ih in zip(x.weight_hh.chunk(3, 0), x.weight_ih.chunk(3, 0)):
            torch.nn.init.orthogonal_(hh)
            torch.nn.init.xavier_uniform_(ih)
        torch.nn.init.zeros_(x.bias_ih)
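
For the orthogonal part, this is the kind of check I would use to compare the two options. It is only a sketch: gates_are_orthogonal is a hypothetical helper, and the sizes, seed and tolerance are arbitrary choices of mine:

```python
import torch
from torch.nn import GRUCell

torch.manual_seed(0)
hidden_size = 20  # arbitrary illustrative size
cell = GRUCell(10, hidden_size)
eye = torch.eye(hidden_size)

def gates_are_orthogonal(weight_hh, atol=1e-4):
    # Hypothetical helper: True if every hidden_size x hidden_size gate block
    # satisfies G @ G.T == I up to the tolerance.
    gates = weight_hh.detach().chunk(3, 0)
    return all(torch.allclose(g @ g.T, eye, atol=atol) for g in gates)

# Option 1: one call on the stacked (3 * hidden_size, hidden_size) matrix.
# Its columns come out orthonormal, but the individual gate blocks do not.
torch.nn.init.orthogonal_(cell.weight_hh)
whole_matrix_result = gates_are_orthogonal(cell.weight_hh)

# Option 2: one call per gate, as in my snippet.
for hh in cell.weight_hh.chunk(3, 0):
    torch.nn.init.orthogonal_(hh)
per_gate_result = gates_are_orthogonal(cell.weight_hh)

print("whole matrix:", whole_matrix_result)
print("per gate:", per_gate_result)
```

If I read this correctly, only the chunked version gives each gate its own orthogonal hidden-to-hidden matrix.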