Initializing RNN, GRU and LSTM correctly

From what I can see, PyTorch initializes every weight in the recurrent layers with a normal distribution, but I don't know how the biases are initialized.
Can someone tell me how to properly initialize one of these layers, such as GRU? I am looking for the same initialization that Keras uses: zeros for the biases, xavier_uniform for the input weights, and orthogonal for the recurrent weights.

Thanks in advance!
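For comparison, the PyTorch documentation for nn.RNN, nn.GRU and nn.LSTM states that all weights and biases start from a uniform distribution U(-sqrt(k), sqrt(k)) with k = 1/hidden_size, not from a normal distribution. A quick sketch to check this on a freshly built GRU (the sizes are arbitrary illustration values):

```python
import math

import torch.nn as nn

# Small GRU; input_size=8 and hidden_size=16 are arbitrary.
hidden_size = 16
gru = nn.GRU(input_size=8, hidden_size=hidden_size)

# Documented default: every weight and bias ~ U(-sqrt(k), sqrt(k)), k = 1/hidden_size.
bound = 1.0 / math.sqrt(hidden_size)
for name, param in gru.named_parameters():
    inside = bool((param.abs() <= bound).all())
    print(f"{name:14s} {tuple(param.shape)} within +/-{bound:.3f}: {inside}")
```

Every line should report True, biases included.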


I am using this initialization. Are there any mistakes here?

    import torch.nn as nn

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.GRU, nn.LSTM, nn.RNN)):
                for name, param in m.named_parameters():
                    if 'weight_ih' in name:
                        nn.init.xavier_uniform_(param)
                    elif 'weight_hh' in name:
                        nn.init.orthogonal_(param)
                    elif 'bias' in name:
                        nn.init.zeros_(param)

You have to pay attention to weight_hh and weight_ih, because each of them is a concatenation of per-gate sub-matrices (four for LSTM, three for GRU). Thus the behaviour you want isn’t actually happening. You have to call init.orthogonal_(sub_matrix) on each of them separately.
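To see the concatenation directly, a quick sketch comparing the recurrent weights of a GRU and an LSTM (the sizes are arbitrary):

```python
import torch.nn as nn

hidden_size, input_size = 16, 8
gru = nn.GRU(input_size, hidden_size)
lstm = nn.LSTM(input_size, hidden_size)

# Each weight tensor stacks one (hidden_size, in_features) block per gate along dim 0:
# GRU has 3 gates (r, z, n), LSTM has 4 (i, f, g, o).
print(gru.weight_hh_l0.shape)   # torch.Size([48, 16])
print(lstm.weight_hh_l0.shape)  # torch.Size([64, 16])

# Splitting along dim 0 recovers the per-gate blocks (only inspected here, not modified).
gru_blocks = gru.weight_hh_l0.chunk(3, dim=0)
lstm_blocks = lstm.weight_hh_l0.chunk(4, dim=0)
print([tuple(b.shape) for b in lstm_blocks])  # [(16, 16), (16, 16), (16, 16), (16, 16)]
```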


Can you provide the code for a correct initialization please?

    for idx in range(4):  # 4 gate blocks for LSTM; use 3 for GRU
        mul = param.shape[0] // 4
        torch.nn.init.xavier_uniform_(param[idx*mul:(idx+1)*mul])

for every “weight_ih” and “weight_hh”.

Sorry, but I don’t know how to format code in here.
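A runnable version of the loop above, with the gate count derived from hidden_size instead of a hard-coded 4, so the same helper covers GRU (3 blocks) and LSTM (4 blocks). xavier_per_gate_ is a hypothetical name, and like the snippet it applies Xavier to both weight_ih and weight_hh:

```python
import torch
import torch.nn as nn

def xavier_per_gate_(rnn: nn.RNNBase) -> None:
    """Apply xavier_uniform_ separately to each per-gate block of every weight matrix."""
    H = rnn.hidden_size
    for name, param in rnn.named_parameters():
        if "weight_ih" in name or "weight_hh" in name:
            # Stacked gate blocks along dim 0: 3 for GRU, 4 for LSTM, 1 for a plain RNN.
            num_gates = param.shape[0] // H
            for idx in range(num_gates):
                # Slicing rows gives a view, so the init writes straight into param
                # (nn.init functions run under torch.no_grad internally).
                nn.init.xavier_uniform_(param[idx * H:(idx + 1) * H])

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16)
xavier_per_gate_(lstm)
```

Each block then uses its own fan-in and fan-out (hidden_size rows) rather than the 4*hidden_size rows of the stacked tensor, which gives a wider Xavier range than initializing the whole matrix at once.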


What about the orthogonal initialization?

The main point is that the eight weight matrices of an LSTM are packed into two tensors for performance: “weight_ih” is the concatenation (W_ii|W_if|W_ig|W_io), and “weight_hh” is the concatenation (W_hi|W_hf|W_hg|W_ho). So you apply the orthogonal initialization to each sub-matrix of “weight_hh” and the Xavier initialization to each sub-matrix of “weight_ih”.
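Putting the two rules together, a minimal sketch of a Keras-like initializer for nn.GRU and nn.LSTM, assuming each gate block is hidden_size rows (true as long as proj_size is not used). keras_style_init_ is a hypothetical helper name, not a PyTorch API:

```python
import torch
import torch.nn as nn

def keras_style_init_(rnn: nn.RNNBase) -> None:
    """Per-gate init: xavier_uniform_ input weights, orthogonal_ recurrent weights, zero biases."""
    H = rnn.hidden_size
    with torch.no_grad():
        for name, param in rnn.named_parameters():
            if "bias" in name:
                nn.init.zeros_(param)
                continue
            num_gates = param.shape[0] // H  # 3 for GRU, 4 for LSTM
            for g in range(num_gates):
                block = param[g * H:(g + 1) * H]  # view into param
                if "weight_ih" in name:
                    nn.init.xavier_uniform_(block)
                elif "weight_hh" in name:
                    nn.init.orthogonal_(block)

torch.manual_seed(0)
gru = nn.GRU(input_size=8, hidden_size=16, num_layers=2)
keras_style_init_(gru)
```

Because the name checks use substrings, the "_reverse" parameters of a bidirectional layer are handled the same way.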

I initialize each of the hidden-to-hidden weight matrices as an identity and then stack them. My question is: when I apply torch.nn.init.orthogonal_, does it make the separate (hidden_size, hidden_size) matrices orthogonal, or the whole concatenated matrix with dimensions (4*hidden_size, hidden_size)?
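On that question: orthogonal_ only sees the tensor it is given, so applied to the stacked (4*hidden_size, hidden_size) matrix it produces a semi-orthogonal matrix whose columns are orthonormal as a whole, while the individual gate blocks are not orthogonal. A small check (hidden size 16 is arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H = 16
w = torch.empty(4 * H, H)
nn.init.orthogonal_(w)  # applied to the whole stacked (4H, H) matrix

# The stacked matrix has orthonormal columns: W^T W = I (H x H).
print(torch.allclose(w.T @ w, torch.eye(H), atol=1e-4))  # True

# A single (H, H) gate block is generally NOT orthogonal.
block = w[:H]
print(torch.allclose(block.T @ block, torch.eye(H), atol=1e-4))  # False
```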

    import torch

    # INITIALIZE LSTM WEIGHTS AND BIASES
    def initHidden(self):
        H = self.hidden_units
        with torch.no_grad():
            for name, param in self.lstm_decode.named_parameters():
                if 'weight_ih' in name:
                    # input TO hidden, ORTHOGONAL over the whole stacked matrix || W_ii, W_if, W_ig, W_io
                    torch.nn.init.orthogonal_(param)
                elif 'weight_hh' in name:
                    # hidden TO hidden: one (H, H) IDENTITY per gate (W_hi, W_hf, W_hg, W_ho), stacked to (4H, H)
                    eye = torch.eye(H)
                    param.copy_(torch.stack([eye, eye, eye, eye], dim=0).view(4 * H, H))
                elif 'bias' in name:
                    torch.nn.init.zeros_(param)
                    # set the forget gate slice | (b_ii|b_if|b_ig|b_io)
                    # PyTorch sums bias_ih and bias_hh, so filling both gives an effective forget bias of 2
                    param[H:2 * H].fill_(1)
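One detail about the forget-gate line: PyTorch's LSTM adds bias_ih and bias_hh, so filling the forget slice of both gives an effective forget bias of 2. If the goal is Keras' unit_forget_bias behaviour (an effective +1 over zero biases), a sketch that sets the slice in bias_ih only and builds the stacked identity with torch.eye(H).repeat(4, 1). weight_ih is left at PyTorch's default here, and identity_forget_bias_init_ is a hypothetical name:

```python
import torch
import torch.nn as nn

def identity_forget_bias_init_(lstm: nn.LSTM, forget_bias: float = 1.0) -> None:
    """Identity recurrent blocks, zero biases, effective forget-gate bias of `forget_bias`."""
    H = lstm.hidden_size
    with torch.no_grad():
        for name, param in lstm.named_parameters():
            if "weight_hh" in name:
                # torch.eye(H).repeat(4, 1) stacks four (H, H) identities into (4H, H).
                param.copy_(torch.eye(H).repeat(4, 1))
            elif "bias" in name:
                param.zero_()
                # bias_ih and bias_hh are summed, so set the forget slice in only one of them.
                if "bias_ih" in name:
                    param[H:2 * H].fill_(forget_bias)

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1)
identity_forget_bias_init_(lstm)
```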