How can I make PyTorch save all the weights from all the sub-layers the model is composed of?

Dear PyTorch community,

Today I encountered a bug that baffled me for quite a while. It’s related to saving the trained model weights, and since I don’t know how to fix it, I’m asking here.

I’m implementing the Transformer from scratch in PyTorch. I first wrote the code for its lowest layers, such as Scaled Dot-Product Attention and Multi-Head Attention. Then I created the Encoder and Decoder blocks, followed by the Transformer Encoder and Transformer Decoder classes. Finally, I created the Transformer class.

I am able to train the model to fit a small batch perfectly. After training the model for some number of epochs, I save it with the following code (as explained in the Quickstart and Saving and Loading Models tutorials):

weights_filename = "transformer_model_trained_on_mini_train_dataset_weights_after_" + str(number_of_epochs) + "_epochs.pth"
torch.save(TransformerInstance.state_dict(), weights_filename)

Then later, I load it with:

TransformerInstance.load_state_dict(torch.load("transformer_model_trained_on_mini_train_dataset_weights_after_2048_epochs.pth"))
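One way to check whether a save/load round trip really restores the whole model is to load the saved state_dict into a freshly constructed instance and compare its outputs with the original model on the same input. If some layers were never saved, the fresh copy keeps its random initial weights for them and the outputs differ. This is a minimal sketch, not the code from this post: ToyModel and roundtrip_matches are hypothetical names, and an in-memory buffer stands in for the .pth file.

```python
import io
from functools import partial

import torch
import torch.nn as nn


def roundtrip_matches(make_model, original, example_input):
    """Save original.state_dict(), load it into make_model(), compare outputs."""
    buffer = io.BytesIO()  # in-memory stand-in for the .pth file
    torch.save(original.state_dict(), buffer)
    buffer.seek(0)

    fresh = make_model()  # newly constructed, randomly initialised model
    fresh.load_state_dict(torch.load(buffer))

    # eval() disables dropout so both forward passes are deterministic
    original.eval()
    fresh.eval()
    with torch.no_grad():
        return torch.allclose(original(example_input), fresh(example_input))


class ToyModel(nn.Module):
    """Hypothetical stand-in for the real Transformer."""

    def __init__(self, use_module_list):
        super().__init__()
        layers = [nn.Linear(4, 4), nn.Linear(4, 4)]
        # Only the nn.ModuleList variant makes these layers part of state_dict()
        self.layers = nn.ModuleList(layers) if use_module_list else layers

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


x = torch.randn(3, 4)
list_ok = roundtrip_matches(partial(ToyModel, False), ToyModel(False), x)
module_list_ok = roundtrip_matches(partial(ToyModel, True), ToyModel(True), x)
print(list_ok, module_list_ok)
```

With the plain list the check fails, because the saved state_dict has no entries for those layers, so the fresh model keeps its own random weights; with nn.ModuleList it passes.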

Here’s the problem: saving only stores the weights of the layers defined directly in the Transformer Encoder and Transformer Decoder; it doesn’t save the weights recursively. For example, a layer defined directly in the Transformer Encoder is saved, but the Scaled Dot-Product Attention layer weights are not. To recap the nesting: the Transformer contains the Transformer Encoder, the Transformer Encoder contains the Encoder blocks, each Encoder block contains a Multi-Head Attention layer, and the Multi-Head Attention layer contains the Scaled Dot-Product Attention layers. This is just one example; the point is that the weights don’t get saved “recursively”.

I think the following output will make the problem clear:

Before loading:
TransformerInstance.TransformerEncoderInstance.encoderBlocks[0].MultiHeadAttentionLayer.scaled_dot_product_attention_layers[0].EmbeddingsToQueries.weight:
Parameter containing:
tensor([[-0.0631,  0.0060,  0.0413,  ..., -0.0608,  0.0985, -0.0271],
        [ 0.0273,  0.0964, -0.0380,  ..., -0.0497, -0.0173,  0.0229],
        [-0.0879,  0.0978, -0.0188,  ...,  0.0981,  0.0710,  0.0890],
        ...,
        [-0.0103,  0.0494, -0.0886,  ...,  0.0792, -0.0412,  0.0479],
        [ 0.0925,  0.0802,  0.0850,  ...,  0.0525, -0.0185, -0.0444],
        [-0.0391, -0.0055, -0.0904,  ..., -0.0720, -0.0473, -0.0830]],
       requires_grad=True)
TransformerInstance.TransformerEncoderInstance.EncoderOutputToKeysLayer.weight
Parameter containing:
tensor([[ 0.0413, -0.0990, -0.0773,  ...,  0.0840,  0.0890, -0.0139],
        [ 0.0795,  0.0047, -0.0117,  ..., -0.0940,  0.0053, -0.0362],
        [ 0.0993, -0.0706, -0.0795,  ...,  0.0592,  0.0644,  0.0211],
        ...,
        [-0.0119, -0.0627, -0.0526,  ..., -0.0643, -0.0991, -0.0483],
        [-0.0391,  0.0638, -0.0160,  ...,  0.0366,  0.0023,  0.0325],
        [-0.0905,  0.0021, -0.0823,  ...,  0.0172,  0.0149,  0.0779]],
       requires_grad=True)
After loading:
TransformerInstance.TransformerEncoderInstance.encoderBlocks[0].MultiHeadAttentionLayer.scaled_dot_product_attention_layers[0].EmbeddingsToQueries.weight:
Parameter containing:
tensor([[-0.0631,  0.0060,  0.0413,  ..., -0.0608,  0.0985, -0.0271],
        [ 0.0273,  0.0964, -0.0380,  ..., -0.0497, -0.0173,  0.0229],
        [-0.0879,  0.0978, -0.0188,  ...,  0.0981,  0.0710,  0.0890],
        ...,
        [-0.0103,  0.0494, -0.0886,  ...,  0.0792, -0.0412,  0.0479],
        [ 0.0925,  0.0802,  0.0850,  ...,  0.0525, -0.0185, -0.0444],
        [-0.0391, -0.0055, -0.0904,  ..., -0.0720, -0.0473, -0.0830]],
       requires_grad=True)
TransformerInstance.TransformerEncoderInstance.EncoderOutputToKeysLayer.weight
Parameter containing:
tensor([[-0.5735,  0.4942, -0.2234,  ..., -0.2754, -0.0231,  0.4324],
        [ 0.1600,  0.3717, -0.6777,  ..., -0.2678, -0.3618,  0.0521],
        [ 0.1210,  0.2996, -0.1996,  ...,  0.3147,  0.0737,  0.2909],
        ...,
        [-0.1566,  0.0528,  0.0819,  ..., -0.0080,  0.2442, -0.1649],
        [-0.4223,  0.0990, -0.1015,  ...,  0.2326,  0.0561,  0.1921],
        [-0.3537,  0.2242,  0.1864,  ..., -0.2561,  0.4154, -0.1085]],
       requires_grad=True)

As you can see from the output above, when I load the model, TransformerInstance.TransformerEncoderInstance.EncoderOutputToKeysLayer.weight changes, but TransformerInstance.TransformerEncoderInstance.encoderBlocks[0].MultiHeadAttentionLayer.scaled_dot_product_attention_layers[0].EmbeddingsToQueries.weight does not. I’d like to save the weights of every sub-layer the Transformer class is composed of, not only those defined directly in TransformerEncoderInstance and TransformerDecoderInstance (the latter not shown here).
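To see which sub-layers are invisible to state_dict(), a simple first check is printing model.state_dict().keys() and looking for the nested attention weights. A more targeted sketch walks the registered modules and reports any nn.Module objects stored in plain Python containers. The helper find_unregistered_modules and the Attention/Block classes are hypothetical, and the helper assumes such modules live in list, tuple, or dict attributes.

```python
import torch.nn as nn


def find_unregistered_modules(model: nn.Module) -> list[str]:
    """Hypothetical helper: list attributes of `model` or any registered
    submodule that hold nn.Module objects inside a plain list, tuple or dict.
    Such modules are invisible to state_dict() and parameters()."""
    found = []
    for module_name, module in model.named_modules():
        for attr_name, value in vars(module).items():
            # Skip nn.Module's own bookkeeping (_modules, _parameters, ...)
            if attr_name.startswith("_"):
                continue
            if isinstance(value, dict):
                values = value.values()
            elif isinstance(value, (list, tuple)):
                values = value
            else:
                continue
            if any(isinstance(v, nn.Module) for v in values):
                prefix = module_name + "." if module_name else ""
                found.append(prefix + attr_name)
    return found


class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = [nn.Linear(4, 4) for _ in range(2)]  # plain list: not registered


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = Attention()  # registered submodule
        self.projection = nn.Linear(4, 4)  # registered submodule


report = find_unregistered_modules(Block())
print(report)  # ['attention.heads']
```

Modules nested inside an unregistered container are never visited by named_modules(), so after fixing the reported attributes it is worth running the helper again.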

How can I do this?

Thank you in advance for your help!

PyTorch saves all parameters recursively from all submodules, provided they are properly registered.
I don’t know how TransformerInstance or its internal submodules are defined and registered, but I would guess that e.g. encoderBlocks might be a plain Python list instead of the required nn.ModuleList.
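The difference can be seen in a minimal sketch with two hypothetical attention modules (not the poster’s classes), one holding its heads in a plain list and one in an nn.ModuleList:

```python
import torch.nn as nn


class HeadsInPlainList(nn.Module):
    def __init__(self, num_heads=2, d_model=4):
        super().__init__()
        # Plain Python list: nn.Module does not see these layers
        self.heads = [nn.Linear(d_model, d_model) for _ in range(num_heads)]


class HeadsInModuleList(nn.Module):
    def __init__(self, num_heads=2, d_model=4):
        super().__init__()
        # nn.ModuleList registers each layer as a submodule
        self.heads = nn.ModuleList(
            nn.Linear(d_model, d_model) for _ in range(num_heads)
        )


plain_keys = list(HeadsInPlainList().state_dict().keys())
registered_keys = list(HeadsInModuleList().state_dict().keys())
print(plain_keys)       # []
print(registered_keys)  # ['heads.0.weight', 'heads.0.bias', 'heads.1.weight', 'heads.1.bias']
```

nn.Module only registers a submodule when it is assigned directly as an attribute, added with add_module(), or placed in a container such as nn.ModuleList or nn.ModuleDict. Layers hidden in a plain list are also missing from model.parameters(), so an optimizer built from it never updates them, and model.to(device) does not move them. Registered nested modules get dotted keys such as heads.0.weight, which is how the recursive saving shows up in the state_dict.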
