How can I make PyTorch save all the weights from all the sub-layers the model is composed of?

Dear PyTorch community,

Today I encountered a bug that baffled me for quite a while. It’s related to saving the trained model weights, and I don’t know how to fix it, which is why I’m asking here.

I’m implementing the Transformer from scratch in PyTorch. I first wrote the code for the lowest-level layers, such as Scaled Dot-Product Attention and Multi-Head Attention. Then I created the Encoder and Decoder blocks, followed by the Transformer Encoder and Transformer Decoder classes. Finally, I created the Transformer class.

I am able to train the model to fit a small batch perfectly. After I train the model for some number of epochs, I save it with the following code (as explained in the Quickstart and the Saving and Loading Models tutorials):

weights_filename = "transformer_model_trained_on_mini_train_dataset_weights_after_" + str(number_of_epochs) + "_epochs.pth"
torch.save(TransformerInstance.state_dict(), weights_filename)

Then later, I load it with:

TransformerInstance.load_state_dict(torch.load("transformer_model_trained_on_mini_train_dataset_weights_after_2048_epochs.pth"))
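
As a side note, a quick way to see exactly which parameters end up in the checkpoint is to print the keys of the state_dict. This is a minimal sketch using the variable name from above; only parameters (and buffers) of properly registered submodules appear in it:

# Every key listed here is exactly what torch.save() writes to disk;
# parameters of unregistered submodules will be missing from this list.
for key in TransformerInstance.state_dict().keys():
    print(key)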

Here’s the problem: saving only stores the weights of the layers that belong directly to the Transformer Encoder and the Transformer Decoder. It doesn’t save the weights recursively. What I mean is that, while a layer sitting directly inside the Transformer Encoder, for example, will be saved, the Scaled Dot-Product Attention layer weights will not. To recap the hierarchy: the Transformer contains the Transformer Encoder, the Transformer Encoder contains the Encoder block, the Encoder block contains the Multi-Head Attention layer, and finally the Multi-Head Attention layer contains the Scaled Dot-Product Attention layer. This is just one specific example; the point is that the weights don’t get saved “recursively”.

I think the following output will make the problem clear:

Before loading:
TransformerInstance.TransformerEncoderInstance.encoderBlocks[0].MultiHeadAttentionLayer.scaled_dot_product_attention_layers[0].EmbeddingsToQueries.weight:
Parameter containing:
tensor([[-0.0631,  0.0060,  0.0413,  ..., -0.0608,  0.0985, -0.0271],
        [ 0.0273,  0.0964, -0.0380,  ..., -0.0497, -0.0173,  0.0229],
        [-0.0879,  0.0978, -0.0188,  ...,  0.0981,  0.0710,  0.0890],
        ...,
        [-0.0103,  0.0494, -0.0886,  ...,  0.0792, -0.0412,  0.0479],
        [ 0.0925,  0.0802,  0.0850,  ...,  0.0525, -0.0185, -0.0444],
        [-0.0391, -0.0055, -0.0904,  ..., -0.0720, -0.0473, -0.0830]],
       requires_grad=True)
TransformerInstance.TransformerEncoderInstance.EncoderOutputToKeysLayer.weight
Parameter containing:
tensor([[ 0.0413, -0.0990, -0.0773,  ...,  0.0840,  0.0890, -0.0139],
        [ 0.0795,  0.0047, -0.0117,  ..., -0.0940,  0.0053, -0.0362],
        [ 0.0993, -0.0706, -0.0795,  ...,  0.0592,  0.0644,  0.0211],
        ...,
        [-0.0119, -0.0627, -0.0526,  ..., -0.0643, -0.0991, -0.0483],
        [-0.0391,  0.0638, -0.0160,  ...,  0.0366,  0.0023,  0.0325],
        [-0.0905,  0.0021, -0.0823,  ...,  0.0172,  0.0149,  0.0779]],
       requires_grad=True)
After loading:
TransformerInstance.TransformerEncoderInstance.encoderBlocks[0].MultiHeadAttentionLayer.scaled_dot_product_attention_layers[0].EmbeddingsToQueries.weight:
Parameter containing:
tensor([[-0.0631,  0.0060,  0.0413,  ..., -0.0608,  0.0985, -0.0271],
        [ 0.0273,  0.0964, -0.0380,  ..., -0.0497, -0.0173,  0.0229],
        [-0.0879,  0.0978, -0.0188,  ...,  0.0981,  0.0710,  0.0890],
        ...,
        [-0.0103,  0.0494, -0.0886,  ...,  0.0792, -0.0412,  0.0479],
        [ 0.0925,  0.0802,  0.0850,  ...,  0.0525, -0.0185, -0.0444],
        [-0.0391, -0.0055, -0.0904,  ..., -0.0720, -0.0473, -0.0830]],
       requires_grad=True)
TransformerInstance.TransformerEncoderInstance.EncoderOutputToKeysLayer.weight
Parameter containing:
tensor([[-0.5735,  0.4942, -0.2234,  ..., -0.2754, -0.0231,  0.4324],
        [ 0.1600,  0.3717, -0.6777,  ..., -0.2678, -0.3618,  0.0521],
        [ 0.1210,  0.2996, -0.1996,  ...,  0.3147,  0.0737,  0.2909],
        ...,
        [-0.1566,  0.0528,  0.0819,  ..., -0.0080,  0.2442, -0.1649],
        [-0.4223,  0.0990, -0.1015,  ...,  0.2326,  0.0561,  0.1921],
        [-0.3537,  0.2242,  0.1864,  ..., -0.2561,  0.4154, -0.1085]],
       requires_grad=True)

As you can see from the output above, when I load the model, TransformerInstance.TransformerEncoderInstance.EncoderOutputToKeysLayer.weight changes, but TransformerInstance.TransformerEncoderInstance.encoderBlocks[0].MultiHeadAttentionLayer.scaled_dot_product_attention_layers[0].EmbeddingsToQueries.weight does not. I’d like to save all the weights of all the sub-layers the Transformer class is composed of, not only those of TransformerEncoderInstance and TransformerDecoderInstance (not shown here).

How can I do this?

Thank you in advance for your help!

PyTorch will save all parameters recursively from all submodules as long as they are properly registered.
I don’t know how TransformerInstance or any of its internal submodules are defined and registered, but I would guess that e.g. encoderBlocks might be a plain Python list instead of the required nn.ModuleList.
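
For example, here is a minimal sketch (not your actual code) showing the difference: submodules stored in a plain Python list are invisible to state_dict(), while an nn.ModuleList registers them properly:

import torch.nn as nn

class EncoderWithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # NOT registered: a plain Python list hides these layers from
        # parameters(), state_dict(), .to(), optimizers, etc.
        self.encoderBlocks = [nn.Linear(4, 4) for _ in range(2)]

class EncoderWithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered: nn.ModuleList turns the blocks into proper submodules.
        self.encoderBlocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

print(len(EncoderWithPlainList().state_dict()))   # 0 -> nothing would be saved
print(len(EncoderWithModuleList().state_dict()))  # 4 -> weight and bias of both blocks

The same applies to nn.ModuleDict for dicts of submodules and to nn.ParameterList / nn.ParameterDict for bare parameters.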

I too have checkpointing and saving problems. I am working with a progressively growing GAN with multiple blocks (stored in an nn.ModuleList). Training works well; however, when saving, the wrong weights are saved. Is there a way to check whether PyTorch registered all my layers and sub-blocks properly? I think this might be the solution to my problem.

Could you explain what “wrong weights” means and how you’ve checked it?
You could iterate over the names of all registered parameters via model.named_parameters(), which will list all registered params.
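
Something along these lines (a sketch; model stands in for your top-level module):

# List every registered parameter; anything missing here will also be
# missing from state_dict() and therefore from the saved checkpoint.
for name, param in model.named_parameters():
    print(name, tuple(param.shape))

# state_dict() additionally contains buffers (e.g. BatchNorm running stats).
print(model.state_dict().keys())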

I already wrote about it on Stack Overflow and in the Lightning forum; here are the links to Stack Overflow and Lightning. I flagged the posts as solved there because I thought your answer here had solved the problem, but after a full training run I found that the issue still persists. In the Lightning forum I also posted plots showing my exact problem, and I describe the project in more detail there.

I don’t know if it’s bad manners to just post links to other websites. Let me know if I should describe the problem here again.

Cross-posting the same question to multiple boards is potentially a huge waste of time, as multiple users could end up debugging the same issue, which is why I discourage it.

I fully understand; however, I have been debugging this for over 4 months now and my posts (already a month old) did not get any answers, which is why I asked here again, since this thread seems related to my problem. Also, I don’t know whether it is a PyTorch, Lightning, or general problem. It would be nice if anyone could help or give a hint.