Defining weight manually for LSTM

Recently I have been diving into meta-learning, and I need to change the weights of modules during training. That means I can't use off-the-shelf modules such as torch.nn.Conv2d or torch.nn.LSTM, because I can't pass weights into them. Instead, I have to define the weights manually and call the underlying functional interface.
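For context, here is a minimal sketch of why the weights have to be passed in explicitly. It assumes a MAML-style setup with a single linear layer (not the LSTM from this post). The adapted "fast" weights are computed from the slow weights with `torch.autograd.grad(..., create_graph=True)` and then fed back through a functional call. A regular module with fixed `nn.Parameter`s doesn't allow that. All tensor names and sizes are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical MAML-style inner step on a single linear layer.
torch.manual_seed(0)
w = torch.randn(3, 5, requires_grad=True)   # "slow" (meta) weight
b = torch.zeros(3, requires_grad=True)      # "slow" (meta) bias

x_support, y_support = torch.randn(8, 5), torch.randint(0, 3, (8,))
x_query, y_query = torch.randn(8, 5), torch.randint(0, 3, (8,))
inner_lr = 0.1

# Inner loop: gradients w.r.t. the slow weights, keeping the graph so the
# outer update can differentiate through this step.
support_loss = F.cross_entropy(F.linear(x_support, w, b), y_support)
grad_w, grad_b = torch.autograd.grad(support_loss, (w, b), create_graph=True)
fast_w = w - inner_lr * grad_w
fast_b = b - inner_lr * grad_b

# Outer loop: evaluate the adapted ("fast") weights with the same functional call.
query_loss = F.cross_entropy(F.linear(x_query, fast_w, fast_b), y_query)
query_loss.backward()  # gradients flow back into w and b

print(w.grad.shape, b.grad.shape)  # torch.Size([3, 5]) torch.Size([3])
```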

For convolution and batch normalization layers, PyTorch provides torch.nn.functional.conv2d and torch.nn.functional.batch_norm, which are easy to call. Things are different for LSTM: there is no interface like torch.nn.functional.lstm.
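A quick sketch of the functional interface for the layers that do have one. When the weights of an existing `nn.Conv2d` and `nn.BatchNorm2d` are passed in explicitly, `F.conv2d` and `F.batch_norm` should reproduce the module outputs. The shapes are arbitrary. Passing `None` for the running statistics assumes training mode, where batch statistics are used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
bn = nn.BatchNorm2d(8)  # training mode by default, so batch statistics are used
x = torch.randn(2, 3, 16, 16)

# Module path.
expected = bn(conv(x))

# Functional path: same computation, weights supplied explicitly.
out = F.conv2d(x, conv.weight, conv.bias, padding=1)
# No running statistics are tracked here; in training mode they are not needed for the output.
out = F.batch_norm(out, None, None, weight=bn.weight, bias=bn.bias, training=True)

print(torch.allclose(out, expected, atol=1e-5))
```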

So I looked through the docs and source of the torch.nn.LSTM module and found the interface torch._VF.lstm (imported there as _VF). I call this interface with my self-defined weights, and the code runs normally. However, the training result is worse than with the torch.nn.LSTM module: on a text recognition task I got 80% accuracy with the LSTM module but only 70% with _VF.lstm. So there must be something I didn't notice. Can anybody give me some advice? What is wrong with constructing an LSTM layer like this?

Thanks a lot!
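One way to rule out the call itself is to feed `_VF.lstm` the weights of a regular `nn.LSTM` and compare the outputs. The sketch below assumes the positional argument order used inside `nn.LSTM.forward` in recent PyTorch releases. `_VF` is private, so this may change between versions. If both outputs match, the accuracy gap is more likely caused by how the weights are created, registered, or initialized than by `_VF.lstm`.

```python
import torch
import torch.nn as nn
from torch import _VF  # private module used internally by nn.LSTM; may change between versions

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 7, 4, 10, 16
lstm = nn.LSTM(input_size, hidden_size, num_layers=1, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)

# Initial states: (num_layers * num_directions, batch, hidden_size).
h0 = torch.zeros(2, batch, hidden_size)
c0 = torch.zeros(2, batch, hidden_size)

# Weights in the order nn.LSTM keeps them: forward direction first, then reverse.
weights = [lstm.weight_ih_l0, lstm.weight_hh_l0, lstm.bias_ih_l0, lstm.bias_hh_l0,
           lstm.weight_ih_l0_reverse, lstm.weight_hh_l0_reverse,
           lstm.bias_ih_l0_reverse, lstm.bias_hh_l0_reverse]

# Positional arguments, as in nn.LSTM.forward:
# (input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)
out, h_n, c_n = _VF.lstm(x, (h0, c0), weights, True, 1, 0.0, lstm.training, True, False)

ref_out, (ref_h, ref_c) = lstm(x, (h0, c0))
print(torch.allclose(out, ref_out), torch.allclose(h_n, ref_h), torch.allclose(c_n, ref_c))
```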

Here is the key part of my code. Since it is a demo version, I only consider the single-layer case with bias. Some variable names were changed for readability; the original code runs normally.

The first part initializes the weights; it is called when the network is initialized.

import math
import torch
import torch.nn as nn

def create_lstm_weight(self, device):
    # Assumes self.input_size, self.hidden_size and self.bidirectional are set on the module.
    # Shapes follow nn.LSTM for a single layer with bias.
    def one_direction():
        return [nn.Parameter(torch.ones(4 * self.hidden_size, self.input_size, device=device)),   # W_ih
                nn.Parameter(torch.ones(4 * self.hidden_size, self.hidden_size, device=device)),  # W_hh (hidden_size columns, not input_size)
                nn.Parameter(torch.ones(4 * self.hidden_size, device=device)),                    # b_ih
                nn.Parameter(torch.ones(4 * self.hidden_size, device=device))]                    # b_hh
    param_list = one_direction()
    if self.bidirectional:
        param_list.extend(one_direction())  # W_ih_reverse, W_hh_reverse, b_ih_reverse, b_hh_reverse
    # flatten the weights into one contiguous cuDNN buffer, as nn.LSTM does (private API;
    # this is the older signature, newer PyTorch versions also take proj_size after hidden_size)
    if param_list[0].is_cuda and torch.backends.cudnn.is_acceptable(param_list[0]):
        import torch.backends.cudnn.rnn as rnn
        with torch.cuda.device_of(param_list[0]), torch.no_grad():
            torch._cudnn_rnn_flatten_weight(
                param_list, 4,  # 4 tensors per direction because the layer has bias
                self.input_size, rnn.get_cudnn_mode('LSTM'), self.hidden_size,
                num_layers=1, batch_first=False, bidirectional=self.bidirectional)
    # initialize the weights the same way nn.LSTM.reset_parameters does
    bound = 1.0 / math.sqrt(self.hidden_size)
    for p in param_list:
        torch.nn.init.uniform_(p, -bound, bound)
    # a ParameterList assigned to a module attribute is registered, so model.parameters()
    # (and the optimizer built from it) will include these weights
    return nn.ParameterList(param_list)
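For comparison, here is a minimal sketch of weight creation that matches `nn.LSTM`'s parameter shapes and order. It returns an `nn.ParameterList` so the tensors get registered with the owning module. `make_lstm_params` is a hypothetical helper. The cuDNN flattening step is left out because it concerns memory layout on the GPU, not the computation.

```python
import math
import torch
import torch.nn as nn

def make_lstm_params(input_size, hidden_size, bidirectional=False):
    """Hypothetical helper: single-layer LSTM weights in nn.LSTM's order and shapes."""
    def one_direction():
        return [nn.Parameter(torch.empty(4 * hidden_size, input_size)),   # W_ih
                nn.Parameter(torch.empty(4 * hidden_size, hidden_size)),  # W_hh
                nn.Parameter(torch.empty(4 * hidden_size)),               # b_ih
                nn.Parameter(torch.empty(4 * hidden_size))]               # b_hh
    params = one_direction()
    if bidirectional:
        params += one_direction()  # the *_reverse weights
    bound = 1.0 / math.sqrt(hidden_size)
    for p in params:
        nn.init.uniform_(p, -bound, bound)  # same scheme as nn.LSTM.reset_parameters
    return nn.ParameterList(params)

params = make_lstm_params(10, 16, bidirectional=True)
reference = nn.LSTM(10, 16, bidirectional=True)
print([tuple(p.shape) for p in params] == [tuple(p.shape) for p in reference.parameters()])  # True
```

Registration matters here. A plain Python list of `nn.Parameter`s stored on a module is invisible to `model.parameters()`, so an optimizer built from that call would never update those weights.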

The second part is the forward pass; it is called in the forward method of the network.

import torch
from torch import _VF  # private module that nn.LSTM itself calls

def lstm_forward(self, x, param):
    # x: [time_step_length, batch_size, feature_dim]
    time_step, batch_size, _ = x.shape
    num_directions = 2 if self.bidirectional else 1
    # initial hidden and cell states: (num_layers * num_directions, batch_size, hidden_size)
    h0 = torch.zeros(num_directions, batch_size, self.hidden_size, device=x.device, dtype=x.dtype)
    c0 = torch.zeros(num_directions, batch_size, self.hidden_size, device=x.device, dtype=x.dtype)
    weights = list(param)  # [W_ih, W_hh, b_ih, b_hh] (+ the *_reverse tensors)

    # positional arguments, in the order nn.LSTM.forward uses:
    # input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first
    result = _VF.lstm(x, (h0, c0), weights, True, 1, 0.0, self.training,
                      self.bidirectional, False)
    output, h = result[0], result[1:]  # h = (h_n, c_n)

    return output, h
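If relying on the private `_VF` API is a concern, a pure-PyTorch reference LSTM is another option. The sketch below covers a single layer and a single direction via a hypothetical `manual_lstm` helper. It uses the gate layout documented for `nn.LSTM` (input, forget, cell, output, stacked along dim 0 of the weights) and checks itself against `nn.LSTM`. It is slower than the fused kernel but is built only from ordinary differentiable ops.

```python
import torch
import torch.nn as nn

def manual_lstm(x, h0, c0, w_ih, w_hh, b_ih, b_hh):
    """Single-layer, single-direction LSTM with plain tensor ops.

    x: (seq_len, batch, input_size); h0, c0: (batch, hidden_size).
    Gates are stacked as input, forget, cell, output (nn.LSTM's layout).
    """
    h, c = h0, c0
    outputs = []
    for x_t in x:  # iterate over the time dimension
        gates = x_t @ w_ih.T + b_ih + h @ w_hh.T + b_hh
        i, f, g, o = gates.chunk(4, dim=1)
        i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)
        c = f * c + i * g
        h = o * torch.tanh(c)
        outputs.append(h)
    return torch.stack(outputs), h, c

torch.manual_seed(0)
lstm = nn.LSTM(10, 16)
x = torch.randn(7, 4, 10)
h0 = torch.zeros(4, 16)
c0 = torch.zeros(4, 16)

out, h_n, c_n = manual_lstm(x, h0, c0, lstm.weight_ih_l0, lstm.weight_hh_l0,
                            lstm.bias_ih_l0, lstm.bias_hh_l0)
ref_out, (ref_h, ref_c) = lstm(x)  # default zero initial states
print(torch.allclose(out, ref_out, atol=1e-5), torch.allclose(c_n, ref_c[0], atol=1e-5))
```

In a meta-learning setting this variant may also help for another reason. As far as I know, double backward (`create_graph=True`) through the fused cuDNN LSTM kernel is not implemented, whereas this loop is ordinary autograd.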

I am also interested in this, because I am having the same problem right now.
I was looking for a functional version of torch.nn.LSTM() but found there is none. Digging inside, I found the _VF.lstm() function. Did you end up figuring out what the problem was?