Dear PyTorch community,
Today I ran into a bug that baffled me for quite a while. It's related to saving the trained model weights, and I don't know how to fix it, so I'm asking here.
I'm implementing the Transformer from scratch in PyTorch. I first wrote the lowest-level layers, such as Scaled Dot-Product Attention and Multi-Head Attention, then built the Encoder and Decoder blocks on top of them, then the Transformer Encoder and Transformer Decoder classes, and finally the Transformer class itself.
I am able to train the model to fit a small batch perfectly. After training for some number of epochs, I save the model with the following code (as explained in the Quickstart and the Saving and Loading Models tutorials):
weights_filename = "transformer_model_trained_on_mini_train_dataset_weights_after_" + str(number_of_epochs) + "_epochs.pth"
torch.save(TransformerInstance.state_dict(), weights_filename)
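For reference, which parameters actually end up in that file can be checked by listing the keys of the state dict, e.g.:

for parameter_name in TransformerInstance.state_dict().keys():
    print(parameter_name)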
Then later, I load it with:
TransformerInstance.load_state_dict(torch.load("transformer_model_trained_on_mini_train_dataset_weights_after_2048_epochs.pth"))
Here’s the problem: saving only stores the weights of the layers that belong directly to the Transformer Encoder and the Transformer Decoder. The weights are not saved recursively. In other words, a layer that sits directly inside the Transformer Encoder is saved, but the weights of the Scaled Dot-Product Attention layer, for example, are not. To recap the structure: the Transformer contains the Transformer Encoder, the Transformer Encoder contains the Encoder block, the Encoder block contains the Multi-Head Attention layer, and the Multi-Head Attention layer contains the Scaled Dot-Product Attention layer. That is just one specific path; the point is that the weights don’t get saved “recursively”.
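To make clear what I mean by saving “recursively”, here is a tiny, self-contained example (not my actual code) with the same kind of nesting:

import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.InnerInstance = Inner()       # nested one level down
        self.top_linear = nn.Linear(4, 4)  # lives directly on Outer

print(list(Outer().state_dict().keys()))
# ['InnerInstance.linear.weight', 'InnerInstance.linear.bias',
#  'top_linear.weight', 'top_linear.bias']

In this toy example the nested InnerInstance.linear weights show up in the state dict alongside top_linear. That is what I would expect for my Transformer as well, but the deeply nested weights apparently don’t make it in.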
I think the following output will make the problem clear:
Before loading:
TransformerInstance.TransformerEncoderInstance.encoderBlocks[0].MultiHeadAttentionLayer.scaled_dot_product_attention_layers[0].EmbeddingsToQueries.weight:
Parameter containing:
tensor([[-0.0631, 0.0060, 0.0413, ..., -0.0608, 0.0985, -0.0271],
[ 0.0273, 0.0964, -0.0380, ..., -0.0497, -0.0173, 0.0229],
[-0.0879, 0.0978, -0.0188, ..., 0.0981, 0.0710, 0.0890],
...,
[-0.0103, 0.0494, -0.0886, ..., 0.0792, -0.0412, 0.0479],
[ 0.0925, 0.0802, 0.0850, ..., 0.0525, -0.0185, -0.0444],
[-0.0391, -0.0055, -0.0904, ..., -0.0720, -0.0473, -0.0830]],
requires_grad=True)
TransformerInstance.TransformerEncoderInstance.EncoderOutputToKeysLayer.weight:
Parameter containing:
tensor([[ 0.0413, -0.0990, -0.0773, ..., 0.0840, 0.0890, -0.0139],
[ 0.0795, 0.0047, -0.0117, ..., -0.0940, 0.0053, -0.0362],
[ 0.0993, -0.0706, -0.0795, ..., 0.0592, 0.0644, 0.0211],
...,
[-0.0119, -0.0627, -0.0526, ..., -0.0643, -0.0991, -0.0483],
[-0.0391, 0.0638, -0.0160, ..., 0.0366, 0.0023, 0.0325],
[-0.0905, 0.0021, -0.0823, ..., 0.0172, 0.0149, 0.0779]],
requires_grad=True)
After loading:
TransformerInstance.TransformerEncoderInstance.encoderBlocks[0].MultiHeadAttentionLayer.scaled_dot_product_attention_layers[0].EmbeddingsToQueries.weight:
Parameter containing:
tensor([[-0.0631, 0.0060, 0.0413, ..., -0.0608, 0.0985, -0.0271],
[ 0.0273, 0.0964, -0.0380, ..., -0.0497, -0.0173, 0.0229],
[-0.0879, 0.0978, -0.0188, ..., 0.0981, 0.0710, 0.0890],
...,
[-0.0103, 0.0494, -0.0886, ..., 0.0792, -0.0412, 0.0479],
[ 0.0925, 0.0802, 0.0850, ..., 0.0525, -0.0185, -0.0444],
[-0.0391, -0.0055, -0.0904, ..., -0.0720, -0.0473, -0.0830]],
requires_grad=True)
TransformerInstance.TransformerEncoderInstance.EncoderOutputToKeysLayer.weight:
Parameter containing:
tensor([[-0.5735, 0.4942, -0.2234, ..., -0.2754, -0.0231, 0.4324],
[ 0.1600, 0.3717, -0.6777, ..., -0.2678, -0.3618, 0.0521],
[ 0.1210, 0.2996, -0.1996, ..., 0.3147, 0.0737, 0.2909],
...,
[-0.1566, 0.0528, 0.0819, ..., -0.0080, 0.2442, -0.1649],
[-0.4223, 0.0990, -0.1015, ..., 0.2326, 0.0561, 0.1921],
[-0.3537, 0.2242, 0.1864, ..., -0.2561, 0.4154, -0.1085]],
requires_grad=True)
As you can see from the output above, when I load the checkpoint, TransformerInstance.TransformerEncoderInstance.EncoderOutputToKeysLayer.weight changes, but TransformerInstance.TransformerEncoderInstance.encoderBlocks[0].MultiHeadAttentionLayer.scaled_dot_product_attention_layers[0].EmbeddingsToQueries.weight does not (a programmatic version of this check is sketched below). I'd like to save the weights of all the sub-layers that the Transformer class is composed of, not only those belonging directly to TransformerEncoderInstance and TransformerDecoderInstance (the decoder is not shown here).
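For completeness, the same before/after comparison can also be done programmatically, e.g. like this (encoder_block and nested_layer are just shorthand variables I'm introducing here):

encoder_block = TransformerInstance.TransformerEncoderInstance.encoderBlocks[0]
nested_layer = encoder_block.MultiHeadAttentionLayer.scaled_dot_product_attention_layers[0]

# Snapshot the deeply nested weight, reload the checkpoint, then compare.
nested_weight_before = nested_layer.EmbeddingsToQueries.weight.detach().clone()
TransformerInstance.load_state_dict(torch.load("transformer_model_trained_on_mini_train_dataset_weights_after_2048_epochs.pth"))
nested_weight_after = nested_layer.EmbeddingsToQueries.weight.detach()

print(torch.equal(nested_weight_before, nested_weight_after))  # True here, i.e. loading left this weight untouched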
How can I do this?
Thank you in advance for your help!