How to do transfer learning with BERT and a pretrained model?

I have a pretrained model for summarization, and it relies on BERT. It uses bert-base-uncased (English), and I want to replace that with a BERT model for my language. However, my model has a vocabulary of 105879 words, while bert-base-uncased has 30522 words, so I’m getting the following errors:

RuntimeError: Error(s) in loading state_dict for AbsSummarizer:
	size mismatch for bert.model.embeddings.word_embeddings.weight: copying a param with shape torch.Size([30522, 512]) from checkpoint, the shape in current model is torch.Size([105879, 512]).
	size mismatch for decoder.embeddings.weight: copying a param with shape torch.Size([30522, 512]) from checkpoint, the shape in current model is torch.Size([105879, 512]).
	size mismatch for generator.0.weight: copying a param with shape torch.Size([30522, 512]) from checkpoint, the shape in current model is torch.Size([105879, 512]).
	size mismatch for generator.0.bias: copying a param with shape torch.Size([30522]) from checkpoint, the shape in current model is torch.Size([105879]).

I’ve tried using the resize_token_embeddings function, but it doesn’t map words to vectors properly, so I was wondering whether there is a way to replace the layers with mismatched sizes with new, randomly initialized layers and then fine-tune the model for a couple of epochs. What is the right approach here? This is new to me, so maybe I’m missing the right angle.
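
For context, this is roughly how I tried resize_token_embeddings (a minimal sketch; the tokenizer path is just a placeholder for my language-specific vocabulary):

from transformers import BertModel, BertTokenizer

# Tokenizer for my language (placeholder path), with 105879 tokens.
tokenizer = BertTokenizer.from_pretrained("path/to/my-language-bert")

bert = BertModel.from_pretrained("bert-base-uncased")
bert.resize_token_embeddings(len(tokenizer))
# This only appends randomly initialized rows to the end of the English
# embedding matrix, so the token ids produced by my tokenizer no longer
# line up with the vectors they are supposed to map to.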

Here is the state dict of the model if it helps:

Model's state_dict:
bert.model.embeddings.word_embeddings.weight 	 torch.Size([30522, 512])
bert.model.embeddings.position_embeddings.weight 	 torch.Size([512, 512])
bert.model.embeddings.token_type_embeddings.weight 	 torch.Size([2, 512])
bert.model.embeddings.LayerNorm.weight 	 torch.Size([512])
bert.model.embeddings.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.0.attention.self.query.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.0.attention.self.query.bias 	 torch.Size([512])
bert.model.encoder.layer.0.attention.self.key.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.0.attention.self.key.bias 	 torch.Size([512])
bert.model.encoder.layer.0.attention.self.value.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.0.attention.self.value.bias 	 torch.Size([512])
bert.model.encoder.layer.0.attention.output.dense.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.0.attention.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.0.attention.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.0.attention.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.0.intermediate.dense.weight 	 torch.Size([2048, 512])
bert.model.encoder.layer.0.intermediate.dense.bias 	 torch.Size([2048])
bert.model.encoder.layer.0.output.dense.weight 	 torch.Size([512, 2048])
bert.model.encoder.layer.0.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.0.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.0.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.1.attention.self.query.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.1.attention.self.query.bias 	 torch.Size([512])
bert.model.encoder.layer.1.attention.self.key.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.1.attention.self.key.bias 	 torch.Size([512])
bert.model.encoder.layer.1.attention.self.value.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.1.attention.self.value.bias 	 torch.Size([512])
bert.model.encoder.layer.1.attention.output.dense.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.1.attention.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.1.attention.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.1.attention.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.1.intermediate.dense.weight 	 torch.Size([2048, 512])
bert.model.encoder.layer.1.intermediate.dense.bias 	 torch.Size([2048])
bert.model.encoder.layer.1.output.dense.weight 	 torch.Size([512, 2048])
bert.model.encoder.layer.1.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.1.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.1.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.2.attention.self.query.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.2.attention.self.query.bias 	 torch.Size([512])
bert.model.encoder.layer.2.attention.self.key.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.2.attention.self.key.bias 	 torch.Size([512])
bert.model.encoder.layer.2.attention.self.value.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.2.attention.self.value.bias 	 torch.Size([512])
bert.model.encoder.layer.2.attention.output.dense.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.2.attention.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.2.attention.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.2.attention.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.2.intermediate.dense.weight 	 torch.Size([2048, 512])
bert.model.encoder.layer.2.intermediate.dense.bias 	 torch.Size([2048])
bert.model.encoder.layer.2.output.dense.weight 	 torch.Size([512, 2048])
bert.model.encoder.layer.2.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.2.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.2.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.3.attention.self.query.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.3.attention.self.query.bias 	 torch.Size([512])
bert.model.encoder.layer.3.attention.self.key.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.3.attention.self.key.bias 	 torch.Size([512])
bert.model.encoder.layer.3.attention.self.value.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.3.attention.self.value.bias 	 torch.Size([512])
bert.model.encoder.layer.3.attention.output.dense.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.3.attention.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.3.attention.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.3.attention.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.3.intermediate.dense.weight 	 torch.Size([2048, 512])
bert.model.encoder.layer.3.intermediate.dense.bias 	 torch.Size([2048])
bert.model.encoder.layer.3.output.dense.weight 	 torch.Size([512, 2048])
bert.model.encoder.layer.3.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.3.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.3.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.4.attention.self.query.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.4.attention.self.query.bias 	 torch.Size([512])
bert.model.encoder.layer.4.attention.self.key.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.4.attention.self.key.bias 	 torch.Size([512])
bert.model.encoder.layer.4.attention.self.value.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.4.attention.self.value.bias 	 torch.Size([512])
bert.model.encoder.layer.4.attention.output.dense.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.4.attention.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.4.attention.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.4.attention.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.4.intermediate.dense.weight 	 torch.Size([2048, 512])
bert.model.encoder.layer.4.intermediate.dense.bias 	 torch.Size([2048])
bert.model.encoder.layer.4.output.dense.weight 	 torch.Size([512, 2048])
bert.model.encoder.layer.4.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.4.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.4.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.5.attention.self.query.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.5.attention.self.query.bias 	 torch.Size([512])
bert.model.encoder.layer.5.attention.self.key.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.5.attention.self.key.bias 	 torch.Size([512])
bert.model.encoder.layer.5.attention.self.value.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.5.attention.self.value.bias 	 torch.Size([512])
bert.model.encoder.layer.5.attention.output.dense.weight 	 torch.Size([512, 512])
bert.model.encoder.layer.5.attention.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.5.attention.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.5.attention.output.LayerNorm.bias 	 torch.Size([512])
bert.model.encoder.layer.5.intermediate.dense.weight 	 torch.Size([2048, 512])
bert.model.encoder.layer.5.intermediate.dense.bias 	 torch.Size([2048])
bert.model.encoder.layer.5.output.dense.weight 	 torch.Size([512, 2048])
bert.model.encoder.layer.5.output.dense.bias 	 torch.Size([512])
bert.model.encoder.layer.5.output.LayerNorm.weight 	 torch.Size([512])
bert.model.encoder.layer.5.output.LayerNorm.bias 	 torch.Size([512])
bert.model.pooler.dense.weight 	 torch.Size([512, 512])
bert.model.pooler.dense.bias 	 torch.Size([512])
decoder.embeddings.weight 	 torch.Size([30522, 512])
decoder.pos_emb.pe 	 torch.Size([1, 5000, 512])
decoder.transformer_layers.0.mask 	 torch.Size([1, 5000, 5000])
decoder.transformer_layers.0.self_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.0.self_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.0.self_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.0.self_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.0.self_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.0.self_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.0.self_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.0.self_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.0.context_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.0.context_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.0.context_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.0.context_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.0.context_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.0.context_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.0.context_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.0.context_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.0.feed_forward.w_1.weight 	 torch.Size([2048, 512])
decoder.transformer_layers.0.feed_forward.w_1.bias 	 torch.Size([2048])
decoder.transformer_layers.0.feed_forward.w_2.weight 	 torch.Size([512, 2048])
decoder.transformer_layers.0.feed_forward.w_2.bias 	 torch.Size([512])
decoder.transformer_layers.0.feed_forward.layer_norm.weight 	 torch.Size([512])
decoder.transformer_layers.0.feed_forward.layer_norm.bias 	 torch.Size([512])
decoder.transformer_layers.0.layer_norm_1.weight 	 torch.Size([512])
decoder.transformer_layers.0.layer_norm_1.bias 	 torch.Size([512])
decoder.transformer_layers.0.layer_norm_2.weight 	 torch.Size([512])
decoder.transformer_layers.0.layer_norm_2.bias 	 torch.Size([512])
decoder.transformer_layers.1.mask 	 torch.Size([1, 5000, 5000])
decoder.transformer_layers.1.self_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.1.self_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.1.self_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.1.self_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.1.self_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.1.self_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.1.self_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.1.self_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.1.context_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.1.context_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.1.context_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.1.context_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.1.context_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.1.context_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.1.context_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.1.context_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.1.feed_forward.w_1.weight 	 torch.Size([2048, 512])
decoder.transformer_layers.1.feed_forward.w_1.bias 	 torch.Size([2048])
decoder.transformer_layers.1.feed_forward.w_2.weight 	 torch.Size([512, 2048])
decoder.transformer_layers.1.feed_forward.w_2.bias 	 torch.Size([512])
decoder.transformer_layers.1.feed_forward.layer_norm.weight 	 torch.Size([512])
decoder.transformer_layers.1.feed_forward.layer_norm.bias 	 torch.Size([512])
decoder.transformer_layers.1.layer_norm_1.weight 	 torch.Size([512])
decoder.transformer_layers.1.layer_norm_1.bias 	 torch.Size([512])
decoder.transformer_layers.1.layer_norm_2.weight 	 torch.Size([512])
decoder.transformer_layers.1.layer_norm_2.bias 	 torch.Size([512])
decoder.transformer_layers.2.mask 	 torch.Size([1, 5000, 5000])
decoder.transformer_layers.2.self_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.2.self_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.2.self_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.2.self_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.2.self_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.2.self_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.2.self_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.2.self_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.2.context_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.2.context_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.2.context_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.2.context_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.2.context_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.2.context_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.2.context_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.2.context_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.2.feed_forward.w_1.weight 	 torch.Size([2048, 512])
decoder.transformer_layers.2.feed_forward.w_1.bias 	 torch.Size([2048])
decoder.transformer_layers.2.feed_forward.w_2.weight 	 torch.Size([512, 2048])
decoder.transformer_layers.2.feed_forward.w_2.bias 	 torch.Size([512])
decoder.transformer_layers.2.feed_forward.layer_norm.weight 	 torch.Size([512])
decoder.transformer_layers.2.feed_forward.layer_norm.bias 	 torch.Size([512])
decoder.transformer_layers.2.layer_norm_1.weight 	 torch.Size([512])
decoder.transformer_layers.2.layer_norm_1.bias 	 torch.Size([512])
decoder.transformer_layers.2.layer_norm_2.weight 	 torch.Size([512])
decoder.transformer_layers.2.layer_norm_2.bias 	 torch.Size([512])
decoder.transformer_layers.3.mask 	 torch.Size([1, 5000, 5000])
decoder.transformer_layers.3.self_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.3.self_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.3.self_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.3.self_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.3.self_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.3.self_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.3.self_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.3.self_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.3.context_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.3.context_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.3.context_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.3.context_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.3.context_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.3.context_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.3.context_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.3.context_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.3.feed_forward.w_1.weight 	 torch.Size([2048, 512])
decoder.transformer_layers.3.feed_forward.w_1.bias 	 torch.Size([2048])
decoder.transformer_layers.3.feed_forward.w_2.weight 	 torch.Size([512, 2048])
decoder.transformer_layers.3.feed_forward.w_2.bias 	 torch.Size([512])
decoder.transformer_layers.3.feed_forward.layer_norm.weight 	 torch.Size([512])
decoder.transformer_layers.3.feed_forward.layer_norm.bias 	 torch.Size([512])
decoder.transformer_layers.3.layer_norm_1.weight 	 torch.Size([512])
decoder.transformer_layers.3.layer_norm_1.bias 	 torch.Size([512])
decoder.transformer_layers.3.layer_norm_2.weight 	 torch.Size([512])
decoder.transformer_layers.3.layer_norm_2.bias 	 torch.Size([512])
decoder.transformer_layers.4.mask 	 torch.Size([1, 5000, 5000])
decoder.transformer_layers.4.self_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.4.self_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.4.self_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.4.self_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.4.self_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.4.self_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.4.self_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.4.self_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.4.context_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.4.context_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.4.context_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.4.context_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.4.context_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.4.context_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.4.context_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.4.context_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.4.feed_forward.w_1.weight 	 torch.Size([2048, 512])
decoder.transformer_layers.4.feed_forward.w_1.bias 	 torch.Size([2048])
decoder.transformer_layers.4.feed_forward.w_2.weight 	 torch.Size([512, 2048])
decoder.transformer_layers.4.feed_forward.w_2.bias 	 torch.Size([512])
decoder.transformer_layers.4.feed_forward.layer_norm.weight 	 torch.Size([512])
decoder.transformer_layers.4.feed_forward.layer_norm.bias 	 torch.Size([512])
decoder.transformer_layers.4.layer_norm_1.weight 	 torch.Size([512])
decoder.transformer_layers.4.layer_norm_1.bias 	 torch.Size([512])
decoder.transformer_layers.4.layer_norm_2.weight 	 torch.Size([512])
decoder.transformer_layers.4.layer_norm_2.bias 	 torch.Size([512])
decoder.transformer_layers.5.mask 	 torch.Size([1, 5000, 5000])
decoder.transformer_layers.5.self_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.5.self_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.5.self_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.5.self_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.5.self_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.5.self_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.5.self_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.5.self_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.5.context_attn.linear_keys.weight 	 torch.Size([512, 512])
decoder.transformer_layers.5.context_attn.linear_keys.bias 	 torch.Size([512])
decoder.transformer_layers.5.context_attn.linear_values.weight 	 torch.Size([512, 512])
decoder.transformer_layers.5.context_attn.linear_values.bias 	 torch.Size([512])
decoder.transformer_layers.5.context_attn.linear_query.weight 	 torch.Size([512, 512])
decoder.transformer_layers.5.context_attn.linear_query.bias 	 torch.Size([512])
decoder.transformer_layers.5.context_attn.final_linear.weight 	 torch.Size([512, 512])
decoder.transformer_layers.5.context_attn.final_linear.bias 	 torch.Size([512])
decoder.transformer_layers.5.feed_forward.w_1.weight 	 torch.Size([2048, 512])
decoder.transformer_layers.5.feed_forward.w_1.bias 	 torch.Size([2048])
decoder.transformer_layers.5.feed_forward.w_2.weight 	 torch.Size([512, 2048])
decoder.transformer_layers.5.feed_forward.w_2.bias 	 torch.Size([512])
decoder.transformer_layers.5.feed_forward.layer_norm.weight 	 torch.Size([512])
decoder.transformer_layers.5.feed_forward.layer_norm.bias 	 torch.Size([512])
decoder.transformer_layers.5.layer_norm_1.weight 	 torch.Size([512])
decoder.transformer_layers.5.layer_norm_1.bias 	 torch.Size([512])
decoder.transformer_layers.5.layer_norm_2.weight 	 torch.Size([512])
decoder.transformer_layers.5.layer_norm_2.bias 	 torch.Size([512])
decoder.layer_norm.weight 	 torch.Size([512])
decoder.layer_norm.bias 	 torch.Size([512])
generator.0.weight 	 torch.Size([30522, 512])
generator.0.bias 	 torch.Size([30522])

I think your approach of initializing the embedding layers randomly and retraining them makes sense.
Could you try using the strict=False argument when loading the state_dict via:

model.load_state_dict(state_dict, strict=False)

This should skip the mismatched layers.
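
As a quick sanity check, load_state_dict returns the missing and unexpected keys, so you can see which parameters were actually skipped (small sketch):

incompatible = model.load_state_dict(state_dict, strict=False)
# Parameters the model expects but the checkpoint does not provide;
# these keep their fresh initialization.
print(incompatible.missing_keys)
# Parameters in the checkpoint that the model does not use.
print(incompatible.unexpected_keys)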

I’m interested in a similar thing as well. Does strict=False completely ignore mismatched layers, or does it partially fill them?

I was wrong in my assumption that strict=False would ignore the shape mismatches.
While it will skip all missing and unexpected keys, shape mismatches will still trigger an error.
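
One workaround (a rough sketch, not tested against your exact checkpoint) is to delete the vocabulary-dependent tensors from the checkpoint before loading, so that strict=False only has to deal with missing keys and the four resized layers keep their random initialization for fine-tuning:

import torch

# Hypothetical checkpoint path; adjust if the weights are nested,
# e.g. under checkpoint["model"].
state_dict = torch.load("english_summarizer.pt", map_location="cpu")

# The four tensors from the error message whose shape depends on the
# vocabulary size.
for key in ["bert.model.embeddings.word_embeddings.weight",
            "decoder.embeddings.weight",
            "generator.0.weight",
            "generator.0.bias"]:
    state_dict.pop(key, None)

# `model` is assumed to be the AbsSummarizer built with the 105879-word
# vocabulary; the removed tensors show up as missing keys, which
# strict=False tolerates, and those layers stay randomly initialized.
model.load_state_dict(state_dict, strict=False)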

Ah, thank you. I’ll start a different topic for clarity.