LSTM <-> LSTMCell conversion helper functions

I’m not sure if others would find these useful, even if only as a sort of beginner tutorial/helper, but I’ve written two quick and simple (as-is) functions to convert between an LSTM object and a ModuleList of LSTMCells, in either direction.

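There are a few assumptions being made here, namely that the state dictionary keys will always keep the same names (weight_ih, weight_hh, bias_ih and bias_hh on an LSTMCell, with a per-layer suffix such as _l0 on an LSTM). The functions also only handle a plain unidirectional LSTM.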
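To see the naming convention the conversion relies on, here is a small sketch that prints both sets of keys. The layer sizes are arbitrary values picked for illustration, and the names could differ in other PyTorch versions.

```python
import torch.nn as nn

# Arbitrary sizes, chosen only for illustration.
lstm = nn.LSTM(input_size=3, hidden_size=5, num_layers=2)
cell = nn.LSTMCell(input_size=3, hidden_size=5)

lstm_keys = list(lstm.state_dict().keys())
cell_keys = list(cell.state_dict().keys())

print(cell_keys)  # parameter names of a single cell
print(lstm_keys)  # the same names, once per layer, with a layer suffix
```

If the two lists stop lining up (cell name plus a per-layer suffix), the helper functions would need their key formatting adjusted.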

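import torch.nn as nn
from collections import OrderedDict

def convert_lstm_to_lstm_cells(lstm):
	# Build a separate LSTMCell for each layer so that every layer keeps its own weights.
	# The first layer takes the LSTM's input size; later layers take the hidden size.
	lstm_cells = nn.ModuleList([
		nn.LSTMCell(lstm.input_size if i == 0 else lstm.hidden_size,
			lstm.hidden_size, bias=lstm.bias)
		for i in range(lstm.num_layers)
	])

	# A cell's keys (e.g. weight_ih) map to the LSTM's per-layer keys (e.g. weight_ih_l0).
	key_names = lstm_cells[0].state_dict().keys()
	source = lstm.state_dict()
	for i in range(lstm.num_layers):
		new_dict = OrderedDict([(k, source["%s_l%d" % (k, i)]) for k in key_names])
		lstm_cells[i].load_state_dict(new_dict)

	return lstm_cells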

def convert_lstm_cells_to_lstm(lstm_cells):
	# Assumes the cells are given in layer order and all share one hidden size.
	lstm = nn.LSTM(lstm_cells[0].input_size, lstm_cells[0].hidden_size,
		len(lstm_cells), bias=lstm_cells[0].bias)

	# Rename each cell's keys to the LSTM's per-layer names, e.g. weight_ih -> weight_ih_l0.
	lstm_dict = OrderedDict()
	for i, lstm_cell in enumerate(lstm_cells):
		for k, v in lstm_cell.state_dict().items():
			lstm_dict["%s_l%d" % (k, i)] = v
	lstm.load_state_dict(lstm_dict)

	return lstm
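One way to sanity-check a conversion like this is to run the same input through the LSTM and through a hand-stepped stack of cells and compare the outputs. The sketch below is self-contained, so it copies the weights inline instead of calling the functions above. The helper name run_cells and the sizes are made up for illustration, and it assumes zero initial states and no dropout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes, chosen only for illustration.
lstm = nn.LSTM(input_size=3, hidden_size=5, num_layers=2)
cells = nn.ModuleList(
    [nn.LSTMCell(3 if i == 0 else 5, 5) for i in range(2)]
)

# Copy each layer's parameters from the LSTM into the matching cell.
source = lstm.state_dict()
for i, cell in enumerate(cells):
    cell.load_state_dict({k: source["%s_l%d" % (k, i)] for k in cell.state_dict()})

def run_cells(cells, x):
    """Step a stack of LSTMCells over x of shape (seq_len, batch, input_size)."""
    batch = x.size(1)
    # Zero initial (h, c) per layer, matching nn.LSTM's default.
    states = [
        (x.new_zeros(batch, c.hidden_size), x.new_zeros(batch, c.hidden_size))
        for c in cells
    ]
    outputs = []
    for x_t in x:  # iterate over time steps
        layer_input = x_t
        for i, cell in enumerate(cells):
            h, c = cell(layer_input, states[i])
            states[i] = (h, c)
            layer_input = h  # the next layer consumes this layer's hidden state
        outputs.append(layer_input)
    return torch.stack(outputs)

x = torch.randn(4, 2, 3)  # (seq_len, batch, input_size)
with torch.no_grad():
    out_lstm, _ = lstm(x)
    out_cells = run_cells(cells, x)

print(torch.allclose(out_lstm, out_cells, atol=1e-5))
```

The two outputs should agree up to floating-point error; a large difference would suggest that the weights ended up in the wrong layer or that layers are sharing parameters.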

The need for such a conversion came about because of issues like the ones in the following posts:

Criticisms and comments are welcome.
