I’m not sure if others would find these useful, even if only as a sort of beginner tutorial/helper, but I’ve written two quick and simple (as-is) functions to convert between an LSTM object and a ModuleList of LSTMCells.
There are a few assumptions being made here: namely that the state dictionaries will always keep the same names, and that the LSTM is unidirectional with biases enabled (the defaults).
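To illustrate the naming assumption, here is a minimal sketch that lists the parameter names of both module types. The sizes are arbitrary and chosen only for this example; the point is that each LSTM key should be an LSTMCell key with a `_l<layer>` suffix.

```python
import torch.nn as nn

# Arbitrary sizes, used only for illustration.
lstm = nn.LSTM(input_size=4, hidden_size=3, num_layers=2)
cell = nn.LSTMCell(input_size=4, hidden_size=3)

lstm_keys = list(lstm.state_dict().keys())
cell_keys = list(cell.state_dict().keys())

print(lstm_keys)  # per-layer names, suffixed with _l0, _l1, ...
print(cell_keys)  # the same base names without a layer suffix

# The conversion relies on this mapping holding for every layer.
derived = {"%s_l%d" % (k, i) for k in cell_keys for i in range(lstm.num_layers)}
print(derived == set(lstm_keys))
```

If the last line prints `False` on your PyTorch version (or for a bidirectional LSTM, which adds extra keys), the functions below would need adjusting.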
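import torch.nn as nn
from collections import OrderedDict

def convert_lstm_to_lstm_cells(lstm):
    # Build a separate LSTMCell per layer; multiplying a list would repeat
    # one cell object, so every upper layer would share the same weights.
    lstm_cells = nn.ModuleList(
        [nn.LSTMCell(lstm.input_size, lstm.hidden_size)] +
        [nn.LSTMCell(lstm.hidden_size, lstm.hidden_size)
         for _ in range(lstm.num_layers - 1)])
    key_names = lstm_cells[0].state_dict().keys()
    source = lstm.state_dict()
    for i in range(lstm.num_layers):
        # The LSTM stores layer i's parameters as "<cell key>_l<i>".
        new_dict = OrderedDict([(k, source["%s_l%d" % (k, i)]) for k in key_names])
        lstm_cells[i].load_state_dict(new_dict)
    return lstm_cells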
def convert_lstm_cells_to_lstm(lstm_cells):
    lstm = nn.LSTM(lstm_cells[0].input_size, lstm_cells[0].hidden_size, len(lstm_cells))
    key_names = lstm_cells[0].state_dict().keys()
    lstm_dict = OrderedDict()
    for i, lstm_cell in enumerate(lstm_cells):
        source = lstm_cell.state_dict()
        # Rename each cell key to the LSTM's per-layer name, "<cell key>_l<i>".
        for k in key_names:
            lstm_dict["%s_l%d" % (k, i)] = source[k]
    lstm.load_state_dict(lstm_dict)
    return lstm
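One way to sanity-check a conversion like this is to step the cells over a sequence by hand and compare against the LSTM's output. The sketch below is self-contained, so it copies the weights inline rather than calling the functions above; `run_cells` is a hypothetical helper written for this example, and it assumes zero initial states, no dropout, and input shaped `(seq_len, batch, input_size)`.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=3, num_layers=2)

# One independent cell per layer, loaded with that layer's parameters.
cells = nn.ModuleList(
    [nn.LSTMCell(4 if i == 0 else 3, 3) for i in range(lstm.num_layers)]
)
source = lstm.state_dict()
for i, cell in enumerate(cells):
    cell.load_state_dict(
        OrderedDict((k, source["%s_l%d" % (k, i)]) for k in cell.state_dict())
    )

def run_cells(cells, x):
    """Step a stack of LSTMCells over x of shape (seq_len, batch, input_size)."""
    batch = x.size(1)
    # Zero initial (h, c) for every layer, matching nn.LSTM's default.
    states = [
        (x.new_zeros(batch, c.hidden_size), x.new_zeros(batch, c.hidden_size))
        for c in cells
    ]
    outputs = []
    for x_t in x:
        inp = x_t
        for i, cell in enumerate(cells):
            states[i] = cell(inp, states[i])
            inp = states[i][0]  # hidden state feeds the next layer
        outputs.append(inp)  # top layer's hidden state at this time step
    return torch.stack(outputs)

x = torch.randn(5, 2, 4)
with torch.no_grad():
    expected, _ = lstm(x)
    actual = run_cells(cells, x)

print(torch.allclose(expected, actual, atol=1e-5))
```

The two outputs should agree up to floating-point tolerance; a mismatch usually points to shared or mis-assigned layer weights.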
The need for such a conversion came about because of issues like the ones in the following posts:
Criticisms and comments are welcome.