Can I set values in a model's state_dict / a layer's dropout rate after it has been loaded from a checkpoint?

Summary: I’d like to load a model from a .pt file that was created with a dropout rate of x, set a new value y for the dropout, and continue training the model with the new rate.

I have a trained model that I save to a .pt file. I load the model and optimizer like this:

checkpoint = torch.load('./models/')

In my model class, I create an LSTM layer like this:

self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.n_layers, dropout=dropout_rate)

When initially created, the model was passed a dropout value of 0. In the code to load and retrain, I get the value for dropout_rate from the command-line arguments and set a global variable that is referenced in the class definition (I’m not sure if this could cause the issue, and whether dropout_rate needs to be passed as an argument to the class instead). If I call the retraining code and pass a dropout rate of 0.5, it seems that loading the model from the .pt file overwrites the parameters passed when creating the model object.
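For reference, a sketch of the constructor-argument pattern instead of the global variable (the class name `Net` and the dimensions are hypothetical, chosen to match the shapes printed below):

```python
import torch.nn as nn

class Net(nn.Module):  # hypothetical stand-in for the model in the question
    def __init__(self, input_dim, hidden_dim, n_layers, dropout_rate):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, dropout=dropout_rate)
        self.linear = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.linear(out[-1])

dropout_rate = 0.5  # in the real script this would come from the cmd line args
model = Net(input_dim=177, hidden_dim=75, n_layers=2, dropout_rate=dropout_rate)
```

Passing the rate through __init__ keeps the model self-contained and avoids any question of whether the global was read at class-definition time or at instantiation time.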

Is it possible to set the dropout rate of the LSTM layer to a new value after the model has been loaded from a .pt file?

Running this code nets me the following output:

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

Model's state_dict:
lstm.weight_ih_l0 	 torch.Size([300, 177])
lstm.weight_hh_l0 	 torch.Size([300, 75])
lstm.bias_ih_l0 	 torch.Size([300])
lstm.bias_hh_l0 	 torch.Size([300])
linear.weight 	 torch.Size([1, 75])
linear.bias 	 torch.Size([1])

Printing model.parameters() provides no useful info, as it is a list of nameless parameters mapped to large tensors. I don’t see where I could find the value of the dropout, as it must be stored somewhere if it is overwriting the initial value.
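One thing worth checking: the dropout rate is not stored in the state_dict at all. The state_dict contains only tensors (weights and biases), while the dropout rate is a plain Python attribute on the nn.LSTM module, so loading a checkpoint cannot overwrite it. A minimal sketch (layer sizes are hypothetical):

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=177, hidden_size=75, num_layers=2, dropout=0.3)

# state_dict holds only the weight/bias tensors - there is no dropout entry
print(list(lstm.state_dict().keys()))
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', 'weight_ih_l1', ...]

# the dropout rate lives on the module object itself
print(lstm.dropout)  # 0.3
```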

Is this at all possible?


Setting the dropout after initializing the model should work, but you would have to make sure the passed value makes sense for the current setup, as the constructor check will be skipped (for example, nn.LSTM normally warns if dropout > 0 is combined with a single layer, and that warning won’t fire when you set the attribute directly).
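A minimal sketch of that workflow, assuming a hypothetical `Net` class matching the shapes in the question (saved to an in-memory buffer here instead of a .pt file to keep the example self-contained):

```python
import io
import torch
import torch.nn as nn

class Net(nn.Module):  # hypothetical model mirroring the one in the question
    def __init__(self, input_dim, hidden_dim, n_layers, dropout_rate):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, dropout=dropout_rate)
        self.linear = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.linear(out[-1])

# pretend this model was trained with dropout=0 and saved
model = Net(input_dim=177, hidden_dim=75, n_layers=2, dropout_rate=0.0)
buf = io.BytesIO()
torch.save(model.state_dict(), buf)

# recreate the model, load the weights, then override the dropout rate;
# the state_dict contains no dropout entry, so nothing overwrites the new value
model = Net(input_dim=177, hidden_dim=75, n_layers=2, dropout_rate=0.0)
buf.seek(0)
model.load_state_dict(torch.load(buf))
model.lstm.dropout = 0.5
print(model.lstm.dropout)  # 0.5
```

Note that in nn.LSTM dropout is only applied between stacked layers, so with n_layers=1 (as the single set of l0 weights in the state_dict above suggests) changing the rate would have no effect on the forward pass.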

I’m not too familiar with loading a model directly from a file, as this might have side effects, but I assume that as long as you can successfully create an instance of your model, the workflow should still work.

Did this work for you?