How do I load ONLY the hidden-state weights of an LSTM into another LSTM before training?

So I have a trained LSTM and I want to load ONLY its rnn.weight_hh_l[k] weights into a new LSTM. I don’t want to load the entire state_dict, just that set of weights.

Can this do the trick?

import torch

emb_dim = 1
hidden_dim = 2
n_layers = 2
bidirectional = False
dropout = 0.1

pattern = "weight_hh_l"


torch.manual_seed(0)
model1 = torch.nn.LSTM(emb_dim, hidden_dim, num_layers = n_layers, bidirectional=bidirectional, dropout = 0 if n_layers < 2 else dropout)

{k : v for k, v in model1.state_dict().items() if k.startswith(pattern)}
"""
{'weight_hh_l0': tensor([[-0.0628,  0.1871],
         [-0.2137, -0.1390],
         [-0.6755, -0.4683],
         [-0.2915,  0.0262],
         [ 0.2795,  0.4243],
         [-0.4794, -0.3079],
         [ 0.2568,  0.5872],
         [-0.1455,  0.5291]]), 'weight_hh_l1': tensor([[ 0.1318, -0.5482],
         [-0.4901, -0.3653],
         [ 0.3199,  0.2844],
         [-0.4189,  0.2136],
         [ 0.3882, -0.0892],
         [ 0.0270,  0.1638],
         [ 0.4387,  0.6790],
         [-0.5449, -0.2591]])}
"""


torch.manual_seed(1)
model2 = torch.nn.LSTM(emb_dim, hidden_dim, num_layers = n_layers, bidirectional=bidirectional, dropout = 0 if n_layers < 2 else dropout)

{k : v for k, v in model2.state_dict().items() if k.startswith(pattern)}
"""
{'weight_hh_l0': tensor([[ 0.0983, -0.0866],
         [ 0.1961,  0.0349],
         [ 0.2583, -0.2756],
         [-0.0516, -0.0637],
         [ 0.1025, -0.0028],
         [ 0.6181,  0.2200],
         [-0.2633, -0.4271],
         [-0.1185, -0.3050]]), 'weight_hh_l1': tensor([[-0.0331, -0.4720],
         [ 0.4306,  0.2195],
         [-0.4571,  0.4593],
         [ 0.4293,  0.6271],
         [-0.3964, -0.1164],
         [-0.0137,  0.1033],
         [-0.5366, -0.5018],
         [ 0.3847, -0.1658]])}
"""

state_dict = {}
for item1, item2 in zip(model1.state_dict().items(), model2.state_dict().items()) :
    key = item1[0] # = item2[0] 
    if key.startswith(pattern) : # or another condition ...
        state_dict[key] = item1[1] # take one of model1
    else :
        state_dict[key] = item2[1] # take one of model2


model2.load_state_dict(state_dict)
"""
<All keys matched successfully>
"""


# check that the selected weights now match between the two models
all(
    torch.equal(v, model2.state_dict()[k])
    for k, v in model1.state_dict().items()
    if k.startswith(pattern)
)
"""
True
"""

Hello,

I managed to implement a solution that is different to yours, but I will see if yours works too. Could you check if there are any obvious errors in my solution?

torch.manual_seed(42)

model_NA1 = LSTM(1, 10, 10, 1)
model_NA2 = LSTM(1, 10, 10, 1)

model_NA1.to(device)
model_NA2.to(device)

learning_rate = 0.0001

optimizer1 = torch.optim.Adam(model_NA1.parameters(), lr = learning_rate)
optimizer2 = torch.optim.Adam(model_NA2.parameters(), lr = learning_rate)

no_epochs = 1000

losses1 = np.zeros(no_epochs)
losses2 = np.zeros(no_epochs)

test_losses1 = np.zeros(no_epochs)
test_losses2 = np.zeros(no_epochs)

for i in range(no_epochs):

    optimizer1.zero_grad()
    optimizer2.zero_grad()

    # Forward pass
    output1 = model_NA1(x_train_model)

    loss1 = loss_function(output1, y_train_model)

    loss1.backward()
    optimizer1.step()

    losses1[i] = loss1.item()

    torch.save(model_NA1.state_dict(), 'W_hh.pt')
    state_dict = torch.load('W_hh.pt')

    with torch.no_grad():
        model_NA2.rnn.weight_hh_l0.copy_(state_dict['rnn.weight_hh_l0'])
        model_NA2.rnn.weight_hh_l1.copy_(state_dict['rnn.weight_hh_l1'])
        model_NA2.rnn.weight_hh_l2.copy_(state_dict['rnn.weight_hh_l2'])
        model_NA2.rnn.weight_hh_l3.copy_(state_dict['rnn.weight_hh_l3'])
        model_NA2.rnn.weight_hh_l4.copy_(state_dict['rnn.weight_hh_l4'])
        model_NA2.rnn.weight_hh_l5.copy_(state_dict['rnn.weight_hh_l5'])
        model_NA2.rnn.weight_hh_l6.copy_(state_dict['rnn.weight_hh_l6'])
        model_NA2.rnn.weight_hh_l7.copy_(state_dict['rnn.weight_hh_l7'])
        model_NA2.rnn.weight_hh_l8.copy_(state_dict['rnn.weight_hh_l8'])
        model_NA2.rnn.weight_hh_l9.copy_(state_dict['rnn.weight_hh_l9'])
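    # Note: a layer-count-agnostic alternative to the ten copy_ calls above could be
    # (just a sketch; it assumes the wrapped nn.LSTM is exposed as .rnn on both models):
    # with torch.no_grad():
    #     for k in range(model_NA2.rnn.num_layers):
    #         getattr(model_NA2.rnn, f"weight_hh_l{k}").copy_(state_dict[f"rnn.weight_hh_l{k}"])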

    #model_NA2.load_state_dict(state_dict, strict=False)
    # (load_state_dict expects a dict, not a file path; and even with strict=False the dict
    #  would need to be filtered to the weight_hh_* keys first, or every parameter gets copied)

    output2 = model_NA2(x_train_model) 

    loss2 = loss_function(output2, y_train_obs)

    loss2.backward()
    optimizer2.step()

    losses2[i] = loss2.item()

    with torch.no_grad():
        test_outputs1 = model_NA1(x_val_model)
        test_outputs2 = model_NA2(x_val_model)
        test_loss1 = loss_function(test_outputs1, y_val_model)
        test_loss2 = loss_function(test_outputs2, y_val_obs)
        
        test_losses1[i] = test_loss1.item()
        test_losses2[i] = test_loss2.item()


    if (i + 1) %100 == 0:
      print(f'Epoch {i+1}/{no_epochs}, NA1 Loss: {loss1.item():.4f}, NA2 Loss: {loss2.item():.4f}')
      #print(f'Epoch {i+1}/{no_epochs}, NA1 Test Loss: {test_loss1.item():.4f}, NA2 Test Loss: {test_loss2.item():.4f}')

I am slightly concerned that when I run this consecutive times, it gives me different out-of-sample accuracy metrics. Any thoughts on why this may be?
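About the different out-of-sample metrics across runs: only torch.manual_seed(42) is set once, so NumPy and CUDA randomness (and, on GPU, nondeterministic cuDNN kernels) are not pinned. Here is a sketch of fuller seeding you could try at the top of the script (set_seed is just a name I chose; whether it removes all variation depends on your setup):

import random
import numpy as np
import torch

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)           # no-op if CUDA is not available
    torch.backends.cudnn.deterministic = True  # ask cuDNN for deterministic kernels
    torch.backends.cudnn.benchmark = False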

Of course your solution works, but it is not agnostic to the model’s hyperparameters, so for each model you will have to rework the code.
Moreover, your code becomes very long (and therefore difficult to maintain) if the values of these hyperparameters become very large (which is generally the case in practice).

I think this function is enough (instead of selecting the parameter names yourself…):

def get_state_dict(model1, model2, pattern):
    state_dict = {}
    for item1, item2 in zip(model1.state_dict().items(), model2.state_dict().items()):
        key = item1[0]  # = item2[0]
        if key.startswith(pattern):  # or another condition ...
            state_dict[key] = item1[1]  # take the one from model1
        else:
            state_dict[key] = item2[1]  # take the one from model2
    return state_dict
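
Then loading only the matching weights is a one-liner (using the function above with the earlier pattern):

model2.load_state_dict(get_state_dict(model1, model2, "weight_hh_l"))
"""
<All keys matched successfully>
"""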

That is very helpful; I will try to add this in later tonight and update you on how it works.

Thank you!