Here is the key part of my code. Since it is a demo version, I only consider the single-layer case, and the weights include biases. Some variable names have been changed for readability; the original code runs without errors.
The first part initializes the weights; it is called when the network is initialized.
def create_lstm_weight(self, device):
    """Create, initialise and (on cuDNN) flatten single-layer LSTM parameters.

    The parameter order matches ``torch.nn.LSTM`` with ``num_layers=1`` and
    biases: ``W_ih, W_hh, b_ih, b_hh``, followed by the same four tensors for
    the reverse direction when ``bi_direction`` is true.

    NOTE(review): ``hidden_size``, ``input_size`` and ``bi_direction`` are not
    defined in this function; they are read from the enclosing (module) scope.
    Confirm they hold the intended configuration.

    Args:
        device: Device on which the parameters are allocated.

    Returns:
        list[nn.Parameter]: 4 parameters, or 8 when ``bi_direction`` is true.
    """
    import math

    def _direction_params():
        # One direction's weights; the leading 4 * hidden_size stacks the four
        # LSTM gates. The recurrent weight multiplies the hidden state, so its
        # second dimension is hidden_size (not input_size).
        return [
            nn.Parameter(torch.empty(4 * hidden_size, input_size, device=device)),   # W_ih
            nn.Parameter(torch.empty(4 * hidden_size, hidden_size, device=device)),  # W_hh
            nn.Parameter(torch.empty(4 * hidden_size, device=device)),               # b_ih
            nn.Parameter(torch.empty(4 * hidden_size, device=device)),               # b_hh
        ]

    param_list = _direction_params()
    if bi_direction:
        # W_ih_reverse, W_hh_reverse, b_ih_reverse, b_hh_reverse (same shapes).
        param_list.extend(_direction_params())

    # Uniform init in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]; this also
    # replaces the uninitialised memory returned by torch.empty.
    bound = 1.0 / math.sqrt(hidden_size)
    for p in param_list:
        torch.nn.init.uniform_(p, a=-bound, b=bound)

    # Flatten the (already initialised) weights into one contiguous cuDNN
    # buffer; the op copies the values and re-points each parameter at a view
    # of that buffer.
    # NOTE(review): moving/replicating the parameters afterwards (e.g. .to(),
    # .cuda(), DataParallel) presumably breaks the contiguity again — confirm.
    # NOTE(review): this is a private op; newer PyTorch releases add a
    # ``proj_size`` argument — verify against the installed version.
    if param_list[0].is_cuda and torch.backends.cudnn.is_acceptable(param_list[0]):
        import torch.backends.cudnn.rnn as rnn

        with torch.cuda.device_of(param_list[0]), torch.no_grad():
            torch._cudnn_rnn_flatten_weight(
                param_list,
                4,  # tensors per direction: W_ih, W_hh, b_ih, b_hh (biases always created)
                input_size,
                rnn.get_cudnn_mode('LSTM'),
                hidden_size,
                num_layers=1,
                batch_first=False,
                bidirectional=bool(bi_direction),  # must agree with len(param_list)
            )
    return param_list
The second part is the forward pass; it is called from the network's `forward` method.
def lstm_forward(self, x, param):
    """Run a single-layer LSTM over a time-major sequence.

    Args:
        x: Input of shape ``[time_step_length, batch_size, feature_dim]``
            (time-major, i.e. ``batch_first=False``).
        param: LSTM parameters in ``torch.nn.LSTM`` order with biases
            (4 tensors, or 8 when bidirectional).

    Returns:
        tuple: ``(output, (h_n, c_n))``. ``output`` has shape
        ``[time_step_length, batch_size, num_directions * hidden_size]``;
        ``h_n`` and ``c_n`` have shape
        ``[num_directions, batch_size, hidden_size]``.

    NOTE(review): ``hidden_size`` and ``bidirectional`` are read from the
    enclosing (module) scope; confirm they match how ``param`` was built.
    """
    # Unpacking enforces a 3-D input; only the batch dimension is needed.
    _, batch_size, _ = x.shape
    num_directions = 2 if bidirectional else 1

    # Zero initial hidden/cell states on the same device and dtype as x.
    h0 = x.new_zeros(num_directions, batch_size, hidden_size)
    c0 = x.new_zeros(num_directions, batch_size, hidden_size)

    # Positional call in the same order torch.nn.LSTM uses:
    # (input, hx, params, has_biases, num_layers, dropout, train,
    #  bidirectional, batch_first).
    output, h_n, c_n = _VF.lstm(
        x,
        (h0, c0),
        param,
        True,   # has_biases
        1,      # num_layers
        0.0,    # dropout
        True,   # train (fixed; with dropout 0.0 no dropout is applied)
        bool(bidirectional),
        False,  # batch_first
    )
    return output, (h_n, c_n)