Can I load a model with the same layer parameters but design my own forward function?

I designed a model that takes two different kinds of inputs, x1 and x2. x2 only passes through part of the model, and its output is used for some constraints. I saved the model after training and tried to reuse it in the test stage. However, I only need input x1 during testing. Is there any way to load the layer parameters of the saved model but with different contents in the "forward" function? Loading the saved weights with model.load_state_dict(torch.load(path)) does not seem to give satisfying results, apparently because it requires the same forward function.

Thanks very much!

The forward method is independent of the state_dict, so you should be able to change this method as you wish.
state_dict = torch.load(path_to_state_dict) only loads the parameters and buffers into an OrderedDict, without any knowledge of the forward method.
Do you see any error when using this approach?
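A minimal sketch of that point, assuming a hypothetical toy model (`TrainNet`, `TestNet`, and the layer sizes are made up for illustration, and an in-memory buffer stands in for the .pth file). The test-time class inherits `__init__`, so its parameter names match the checkpoint exactly, while its `forward` takes only `x1`:

```python
import io

import torch
import torch.nn as nn


class TrainNet(nn.Module):
    """Hypothetical stand-in for the training-time model with two inputs."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, x1, x2):
        out = self.head(torch.relu(self.encoder(x1)))
        # x2 reuses the same encoder; its features feed a training-only constraint
        aux = torch.relu(self.encoder(x2))
        return out, aux


class TestNet(TrainNet):
    """Same layers (inherited __init__), different forward: only x1 is used."""

    def forward(self, x1):
        return self.head(torch.relu(self.encoder(x1)))


# Save the trained parameters; BytesIO stands in for a file on disk
train_model = TrainNet()
buffer = io.BytesIO()
torch.save(train_model.state_dict(), buffer)
buffer.seek(0)

# The loaded object is just a mapping of names to tensors, no forward involved
state_dict = torch.load(buffer)
print(list(state_dict.keys()))
# ['encoder.weight', 'encoder.bias', 'head.weight', 'head.bias']

test_model = TestNet()
test_model.load_state_dict(state_dict)  # strict=True by default; keys match exactly
test_model.eval()
with torch.no_grad():
    y = test_model(torch.randn(4, 8))
print(y.shape)  # torch.Size([4, 2])
```

Because loading matches tensors by name only, any class whose `__init__` registers the same submodules under the same names can load the checkpoint, whatever its `forward` does.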


Yes, I used the following commands:

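```python
import os
import torch

model_path = os.path.join(args.checkpoint, 'model_epoch_{}.pth'.format(str(ep)))
checkpoint = torch.load(model_path)
model = Unet_BiCLSTM()
model.load_state_dict(checkpoint)  # the call that raises the error below
```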

where Unet_BiCLSTM is my network, which differs from the pretrained model only in its forward function. I then get the following error:

```
RuntimeError: Error(s) in loading state_dict for Unet_BiCLSTM:
Missing key(s) in state_dict: "feature_ext.conv0.weight", "feature_ext.norm0.weight", "feature_ext.norm0.bias", "feature_ext.norm0.running_mean", "feature_ext.norm0.running_var", …
```

Is there something wrong with this?
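The error message is truncated, so its exact cause cannot be read off it. Two common reasons for "Missing key(s)" when the layers themselves are unchanged are (a) the file holds a dict that wraps the state_dict together with other entries, or (b) the weights were saved from an `nn.DataParallel` model, whose keys carry a `module.` prefix. A minimal diagnostic sketch under those assumptions; the wrapper key names `'state_dict'` and `'model'` and the helper functions are hypothetical, so check what your training script actually saved:

```python
import torch
import torch.nn as nn


def extract_state_dict(checkpoint):
    """Return a plain state_dict; assumes it may be wrapped under a hypothetical key."""
    for key in ("state_dict", "model"):
        if key in checkpoint and isinstance(checkpoint[key], dict):
            return checkpoint[key]
    return checkpoint


def strip_prefix(state_dict, prefix="module."):
    """Remove a prefix such as the 'module.' that nn.DataParallel adds to every key."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def report_key_mismatch(model, state_dict):
    """Compare the model's expected keys with the keys found in the checkpoint."""
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(state_dict.keys())
    return sorted(model_keys - ckpt_keys), sorted(ckpt_keys - model_keys)


# Toy model and a simulated checkpoint: DataParallel-style keys wrapped in a dict
model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4))
saved = {"epoch": 10,
         "state_dict": {"module." + k: v for k, v in model.state_dict().items()}}

sd = extract_state_dict(saved)
missing, unexpected = report_key_mismatch(model, sd)
print(len(missing), len(unexpected))  # every key mismatches because of the prefix

sd = strip_prefix(sd)
print(report_key_mismatch(model, sd))  # ([], [])
model.load_state_dict(sd)
```

Printing a few keys from both `model.state_dict()` and the loaded checkpoint is usually enough to see which of these cases, if either, applies.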

Thanks very much!

You could pass strict=False to load_state_dict; that might help, if I understand correctly.

Edit: But that would mean the missing weights aren't loaded at all, even though the model expects these layers to be present. Are the layers named the same? Can you load them back into the original model? Or do you mean by "different contents in the forward function" that you're changing the model architecture?
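A small sketch of why strict=False only hides the problem, using a hypothetical two-layer toy model: the call succeeds, but every key absent from the checkpoint keeps its random initialization. The returned object lists those keys in `missing_keys` and `unexpected_keys`, which is worth checking after any non-strict load:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Net(nn.Module):
    """Hypothetical toy model; forward is not needed for loading."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)


trained = Net()
# Simulate a checkpoint that lacks the encoder weights
partial = {k: v for k, v in trained.state_dict().items() if k.startswith("head.")}

fresh = Net()
result = fresh.load_state_dict(partial, strict=False)  # no error raised
print(result.missing_keys)     # ['encoder.weight', 'encoder.bias']
print(result.unexpected_keys)  # []

# head was copied; encoder silently kept its random initialization
print(torch.equal(fresh.head.weight, trained.head.weight))        # True
print(torch.equal(fresh.encoder.weight, trained.encoder.weight))  # False
```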

Yes, I can load them back into the original model. I copied the original model's definition and changed only its forward function: I removed one input, along with the line that passes this input through the same layers as the other input. During training, this input passes through part of the network, and its output is constrained to regularize the network. In the test stage this branch is no longer used, so I want to remove it by changing the forward function.
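One alternative design that avoids maintaining a copied class: keep a single model whose second input is optional, so the training and test stages share the same class and therefore the same state_dict keys. This is a sketch with hypothetical layer names and sizes, not the poster's Unet_BiCLSTM; the regularization term computed from `aux` is left out:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical toy model: x2 is optional and only used for a training-time constraint."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, x1, x2=None):
        out = self.head(torch.relu(self.encoder(x1)))
        if x2 is None:
            return out  # test stage: only x1 is needed
        aux = torch.relu(self.encoder(x2))  # same layers, output used only for the constraint
        return out, aux


model = Net()
# Training stage: both inputs; aux would feed a regularization term (not shown)
out, aux = model(torch.randn(4, 8), torch.randn(4, 8))

# Test stage: same class and parameters, only x1
model.eval()
with torch.no_grad():
    pred = model(torch.randn(4, 8))
print(pred.shape)  # torch.Size([4, 2])
```

With this layout, a checkpoint saved during training loads with the default strict=True, since nothing about the registered layers changes between stages.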