How to initialize the weights of an LSTMCell?

(playma) #1

I am new to PyTorch and do not know how to initialize the trainable parameters of `nn.LSTMCell`.
I want to use `nn.init.orthogonal` (spelled `nn.init.orthogonal_` in current PyTorch) to initialize `nn.LSTMCell`.



This thread might have the solution.

(Søren Koch) #3
# Build a random weight tensor whose rows are rescaled to a fixed L2 norm.
def normalized_columns_initializer(weights, std=1.0):
    """Return a new Gaussian tensor shaped like *weights*, with every row rescaled to L2 norm *std*.

    Only ``weights.size()`` is read: the values in *weights* are ignored and
    *weights* itself is not modified. The result is created by
    ``torch.randn`` with no dtype/device arguments, i.e. with torch's
    default dtype and device (not necessarily those of *weights*).

    After rescaling, each slice along dimension 1 satisfies
    ``out[i].pow(2).sum() == std ** 2`` (up to float rounding). Note this
    fixes the per-row *norm*, not the per-element variance, which is roughly
    ``std ** 2 / weights.size(1)``.

    NOTE(review): despite the name, the normalisation runs along dim 1, so
    for an ``nn.Linear`` weight of shape (out_features, in_features) it is
    each output unit's incoming weight row that gets norm *std*.

    Args:
        weights: tensor whose shape is copied; must have at least 2 dims,
            since the reduction is over dim 1.
        std: target L2 norm of each row.

    Returns:
        The freshly generated, row-normalised tensor.
    """
    out = torch.randn(weights.size())
    # keepdim=True leaves a (rows, 1) norm that broadcasts across each row,
    # so no explicit expand_as is needed.
    out *= std / out.pow(2).sum(1, keepdim=True).sqrt()
    return out

# Initialise Linear-like submodules with Glorot/Xavier-uniform weights and zero bias.
def weights_init(m):
    """Initialise module *m* in place if its class name contains ``'Linear'``.

    Meant to be passed to ``nn.Module.apply``, which calls it once per
    submodule. Any module whose class name does not contain ``'Linear'``
    (for example ``nn.LSTMCell``) is returned untouched.

    For a matching module the weight is drawn from ``U(-b, b)`` with
    ``b = sqrt(6 / (fan_in + fan_out))``, where ``fan_in = weight.size(1)``
    and ``fan_out = weight.size(0)``. The bias, when present, is set to zero.

    Args:
        m: the module to initialise.
    """
    import math

    classname = m.__class__.__name__
    # Matching on the class name (not isinstance) also picks up custom
    # layers whose class names contain 'Linear'.
    if 'Linear' not in classname:
        return
    weight_shape = list(m.weight.data.size())
    fan_in = weight_shape[1]   # dim 1: input features
    fan_out = weight_shape[0]  # dim 0: output features
    w_bound = math.sqrt(6.0 / (fan_in + fan_out))
    with torch.no_grad():
        m.weight.uniform_(-w_bound, w_bound)
        # Linear(..., bias=False) stores bias as None.
        if m.bias is not None:
            m.bias.fill_(0)

# Creating the architecture of the neural network
class LSTM_QNETWORK(nn.Module):
    """Q-network made of an ``nn.LSTMCell`` (30 hidden units) followed by a linear head.

    Each forward call advances the LSTM cell by one step and maps the new
    hidden state to one Q-value per action.
    """

    def __init__(self, input_size, nb_action):
        """Create the layers and initialise their parameters.

        Args:
            input_size: number of input features fed to the LSTM cell.
            nb_action: number of actions, i.e. Q-values produced per step.
        """
        super().__init__()
        self.lstm = nn.LSTMCell(input_size, 30)
        self.fcL = nn.Linear(30, nb_action)
        # Glorot-uniform weights and zero bias for every submodule whose class
        # name contains 'Linear'; the LSTMCell is not matched by weights_init.
        self.apply(weights_init)
        with torch.no_grad():
            # Overwrite the head's weights with rows of L2 norm 0.01, so the
            # initial Q-values start small.
            self.fcL.weight.copy_(normalized_columns_initializer(self.fcL.weight, 0.01))
            self.fcL.bias.fill_(0)
            # Zero both LSTM bias vectors. The LSTM weight matrices keep
            # nn.LSTMCell's default initialisation; for orthogonal init use
            # nn.init.orthogonal_(self.lstm.weight_ih) and
            # nn.init.orthogonal_(self.lstm.weight_hh).
            self.lstm.bias_ih.fill_(0)
            self.lstm.bias_hh.fill_(0)

    def forward(self, inputs, learn_state):
        """Run one LSTM step and compute the Q-values.

        Args:
            inputs: ``(state, (hx, cx))``, i.e. the current input plus the
                previous hidden and cell states of the LSTM cell.
                Presumably each is batch-first (batch, features); confirm
                against callers.
            learn_state: if it is exactly ``False``, the updated recurrent
                state is returned together with the Q-values. Any other
                value returns the Q-values alone.

        Returns:
            ``(q_values, (hx, cx))`` when ``learn_state is False``, otherwise
            ``q_values``.
        """
        state, (hx, cx) = inputs
        # The cell takes the input and the old (hidden, cell) states and
        # returns the updated pair.
        hx, cx = self.lstm(state, (hx, cx))
        # The new hidden state is the cell's output; the head maps it to
        # one value per action.
        q_values = self.fcL(hx)
        if learn_state is False:
            return q_values, (hx, cx)
        return q_values