I have a pretrained summarization model that relies on BERT. It uses bert-base-uncased (English), and I want to replace it with a BERT model for my language. However, my model's vocabulary has 105,879 tokens, while bert-base-uncased has 30,522, so I'm getting the following errors:
RuntimeError: Error(s) in loading state_dict for AbsSummarizer:
size mismatch for bert.model.embeddings.word_embeddings.weight: copying a param with shape torch.Size([30522, 512]) from checkpoint, the shape in current model is torch.Size([105879, 512]).
size mismatch for decoder.embeddings.weight: copying a param with shape torch.Size([30522, 512]) from checkpoint, the shape in current model is torch.Size([105879, 512]).
size mismatch for generator.0.weight: copying a param with shape torch.Size([30522, 512]) from checkpoint, the shape in current model is torch.Size([105879, 512]).
size mismatch for generator.0.bias: copying a param with shape torch.Size([30522]) from checkpoint, the shape in current model is torch.Size([105879]).
I've tried the resize_token_embeddings function, but it doesn't map tokens to vectors properly, so I was wondering whether there is a way to replace the layers with mismatched sizes with new, randomly initialized layers, and then fine-tune the model for a couple of epochs. What is the right approach here? This is new to me, so maybe I'm missing the right angle.
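For reference, this is the kind of thing I had in mind: a minimal sketch (names like `load_compatible_weights` are my own, and I'm assuming the checkpoint is a flat state dict) that skips the vocabulary-dependent tensors so everything else loads, leaving the mismatched layers at their fresh random initialization for fine-tuning.

```python
import torch
import torch.nn as nn

def load_compatible_weights(model: nn.Module, state_dict: dict) -> list:
    """Load only the checkpoint tensors whose name AND shape match
    the current model; everything else (e.g. the four vocabulary-sized
    tensors above) keeps its random initialization.

    Returns the sorted list of skipped keys so they can be inspected.
    """
    model_state = model.state_dict()
    filtered = {
        k: v for k, v in state_dict.items()
        if k in model_state and v.shape == model_state[k].shape
    }
    # strict=False tolerates the keys we deliberately left out.
    model.load_state_dict(filtered, strict=False)
    return sorted(set(state_dict) - set(filtered))
```

With this, I would expect `generator.0.weight`, `generator.0.bias`, `decoder.embeddings.weight`, and `bert.model.embeddings.word_embeddings.weight` to come back as skipped, while all the transformer layers load from the checkpoint, but I'm not sure if that is the right approach.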
Here is the model's state dict, in case it helps:
Model's state_dict:
bert.model.embeddings.word_embeddings.weight torch.Size([30522, 512])
bert.model.embeddings.position_embeddings.weight torch.Size([512, 512])
bert.model.embeddings.token_type_embeddings.weight torch.Size([2, 512])
bert.model.embeddings.LayerNorm.weight torch.Size([512])
bert.model.embeddings.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.0.attention.self.query.weight torch.Size([512, 512])
bert.model.encoder.layer.0.attention.self.query.bias torch.Size([512])
bert.model.encoder.layer.0.attention.self.key.weight torch.Size([512, 512])
bert.model.encoder.layer.0.attention.self.key.bias torch.Size([512])
bert.model.encoder.layer.0.attention.self.value.weight torch.Size([512, 512])
bert.model.encoder.layer.0.attention.self.value.bias torch.Size([512])
bert.model.encoder.layer.0.attention.output.dense.weight torch.Size([512, 512])
bert.model.encoder.layer.0.attention.output.dense.bias torch.Size([512])
bert.model.encoder.layer.0.attention.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.0.attention.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.0.intermediate.dense.weight torch.Size([2048, 512])
bert.model.encoder.layer.0.intermediate.dense.bias torch.Size([2048])
bert.model.encoder.layer.0.output.dense.weight torch.Size([512, 2048])
bert.model.encoder.layer.0.output.dense.bias torch.Size([512])
bert.model.encoder.layer.0.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.0.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.1.attention.self.query.weight torch.Size([512, 512])
bert.model.encoder.layer.1.attention.self.query.bias torch.Size([512])
bert.model.encoder.layer.1.attention.self.key.weight torch.Size([512, 512])
bert.model.encoder.layer.1.attention.self.key.bias torch.Size([512])
bert.model.encoder.layer.1.attention.self.value.weight torch.Size([512, 512])
bert.model.encoder.layer.1.attention.self.value.bias torch.Size([512])
bert.model.encoder.layer.1.attention.output.dense.weight torch.Size([512, 512])
bert.model.encoder.layer.1.attention.output.dense.bias torch.Size([512])
bert.model.encoder.layer.1.attention.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.1.attention.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.1.intermediate.dense.weight torch.Size([2048, 512])
bert.model.encoder.layer.1.intermediate.dense.bias torch.Size([2048])
bert.model.encoder.layer.1.output.dense.weight torch.Size([512, 2048])
bert.model.encoder.layer.1.output.dense.bias torch.Size([512])
bert.model.encoder.layer.1.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.1.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.2.attention.self.query.weight torch.Size([512, 512])
bert.model.encoder.layer.2.attention.self.query.bias torch.Size([512])
bert.model.encoder.layer.2.attention.self.key.weight torch.Size([512, 512])
bert.model.encoder.layer.2.attention.self.key.bias torch.Size([512])
bert.model.encoder.layer.2.attention.self.value.weight torch.Size([512, 512])
bert.model.encoder.layer.2.attention.self.value.bias torch.Size([512])
bert.model.encoder.layer.2.attention.output.dense.weight torch.Size([512, 512])
bert.model.encoder.layer.2.attention.output.dense.bias torch.Size([512])
bert.model.encoder.layer.2.attention.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.2.attention.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.2.intermediate.dense.weight torch.Size([2048, 512])
bert.model.encoder.layer.2.intermediate.dense.bias torch.Size([2048])
bert.model.encoder.layer.2.output.dense.weight torch.Size([512, 2048])
bert.model.encoder.layer.2.output.dense.bias torch.Size([512])
bert.model.encoder.layer.2.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.2.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.3.attention.self.query.weight torch.Size([512, 512])
bert.model.encoder.layer.3.attention.self.query.bias torch.Size([512])
bert.model.encoder.layer.3.attention.self.key.weight torch.Size([512, 512])
bert.model.encoder.layer.3.attention.self.key.bias torch.Size([512])
bert.model.encoder.layer.3.attention.self.value.weight torch.Size([512, 512])
bert.model.encoder.layer.3.attention.self.value.bias torch.Size([512])
bert.model.encoder.layer.3.attention.output.dense.weight torch.Size([512, 512])
bert.model.encoder.layer.3.attention.output.dense.bias torch.Size([512])
bert.model.encoder.layer.3.attention.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.3.attention.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.3.intermediate.dense.weight torch.Size([2048, 512])
bert.model.encoder.layer.3.intermediate.dense.bias torch.Size([2048])
bert.model.encoder.layer.3.output.dense.weight torch.Size([512, 2048])
bert.model.encoder.layer.3.output.dense.bias torch.Size([512])
bert.model.encoder.layer.3.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.3.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.4.attention.self.query.weight torch.Size([512, 512])
bert.model.encoder.layer.4.attention.self.query.bias torch.Size([512])
bert.model.encoder.layer.4.attention.self.key.weight torch.Size([512, 512])
bert.model.encoder.layer.4.attention.self.key.bias torch.Size([512])
bert.model.encoder.layer.4.attention.self.value.weight torch.Size([512, 512])
bert.model.encoder.layer.4.attention.self.value.bias torch.Size([512])
bert.model.encoder.layer.4.attention.output.dense.weight torch.Size([512, 512])
bert.model.encoder.layer.4.attention.output.dense.bias torch.Size([512])
bert.model.encoder.layer.4.attention.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.4.attention.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.4.intermediate.dense.weight torch.Size([2048, 512])
bert.model.encoder.layer.4.intermediate.dense.bias torch.Size([2048])
bert.model.encoder.layer.4.output.dense.weight torch.Size([512, 2048])
bert.model.encoder.layer.4.output.dense.bias torch.Size([512])
bert.model.encoder.layer.4.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.4.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.5.attention.self.query.weight torch.Size([512, 512])
bert.model.encoder.layer.5.attention.self.query.bias torch.Size([512])
bert.model.encoder.layer.5.attention.self.key.weight torch.Size([512, 512])
bert.model.encoder.layer.5.attention.self.key.bias torch.Size([512])
bert.model.encoder.layer.5.attention.self.value.weight torch.Size([512, 512])
bert.model.encoder.layer.5.attention.self.value.bias torch.Size([512])
bert.model.encoder.layer.5.attention.output.dense.weight torch.Size([512, 512])
bert.model.encoder.layer.5.attention.output.dense.bias torch.Size([512])
bert.model.encoder.layer.5.attention.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.5.attention.output.LayerNorm.bias torch.Size([512])
bert.model.encoder.layer.5.intermediate.dense.weight torch.Size([2048, 512])
bert.model.encoder.layer.5.intermediate.dense.bias torch.Size([2048])
bert.model.encoder.layer.5.output.dense.weight torch.Size([512, 2048])
bert.model.encoder.layer.5.output.dense.bias torch.Size([512])
bert.model.encoder.layer.5.output.LayerNorm.weight torch.Size([512])
bert.model.encoder.layer.5.output.LayerNorm.bias torch.Size([512])
bert.model.pooler.dense.weight torch.Size([512, 512])
bert.model.pooler.dense.bias torch.Size([512])
decoder.embeddings.weight torch.Size([30522, 512])
decoder.pos_emb.pe torch.Size([1, 5000, 512])
decoder.transformer_layers.0.mask torch.Size([1, 5000, 5000])
decoder.transformer_layers.0.self_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.0.self_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.0.self_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.0.self_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.0.self_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.0.self_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.0.self_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.0.self_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.0.context_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.0.context_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.0.context_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.0.context_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.0.context_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.0.context_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.0.context_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.0.context_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.0.feed_forward.w_1.weight torch.Size([2048, 512])
decoder.transformer_layers.0.feed_forward.w_1.bias torch.Size([2048])
decoder.transformer_layers.0.feed_forward.w_2.weight torch.Size([512, 2048])
decoder.transformer_layers.0.feed_forward.w_2.bias torch.Size([512])
decoder.transformer_layers.0.feed_forward.layer_norm.weight torch.Size([512])
decoder.transformer_layers.0.feed_forward.layer_norm.bias torch.Size([512])
decoder.transformer_layers.0.layer_norm_1.weight torch.Size([512])
decoder.transformer_layers.0.layer_norm_1.bias torch.Size([512])
decoder.transformer_layers.0.layer_norm_2.weight torch.Size([512])
decoder.transformer_layers.0.layer_norm_2.bias torch.Size([512])
decoder.transformer_layers.1.mask torch.Size([1, 5000, 5000])
decoder.transformer_layers.1.self_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.1.self_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.1.self_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.1.self_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.1.self_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.1.self_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.1.self_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.1.self_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.1.context_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.1.context_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.1.context_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.1.context_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.1.context_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.1.context_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.1.context_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.1.context_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.1.feed_forward.w_1.weight torch.Size([2048, 512])
decoder.transformer_layers.1.feed_forward.w_1.bias torch.Size([2048])
decoder.transformer_layers.1.feed_forward.w_2.weight torch.Size([512, 2048])
decoder.transformer_layers.1.feed_forward.w_2.bias torch.Size([512])
decoder.transformer_layers.1.feed_forward.layer_norm.weight torch.Size([512])
decoder.transformer_layers.1.feed_forward.layer_norm.bias torch.Size([512])
decoder.transformer_layers.1.layer_norm_1.weight torch.Size([512])
decoder.transformer_layers.1.layer_norm_1.bias torch.Size([512])
decoder.transformer_layers.1.layer_norm_2.weight torch.Size([512])
decoder.transformer_layers.1.layer_norm_2.bias torch.Size([512])
decoder.transformer_layers.2.mask torch.Size([1, 5000, 5000])
decoder.transformer_layers.2.self_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.2.self_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.2.self_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.2.self_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.2.self_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.2.self_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.2.self_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.2.self_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.2.context_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.2.context_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.2.context_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.2.context_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.2.context_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.2.context_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.2.context_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.2.context_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.2.feed_forward.w_1.weight torch.Size([2048, 512])
decoder.transformer_layers.2.feed_forward.w_1.bias torch.Size([2048])
decoder.transformer_layers.2.feed_forward.w_2.weight torch.Size([512, 2048])
decoder.transformer_layers.2.feed_forward.w_2.bias torch.Size([512])
decoder.transformer_layers.2.feed_forward.layer_norm.weight torch.Size([512])
decoder.transformer_layers.2.feed_forward.layer_norm.bias torch.Size([512])
decoder.transformer_layers.2.layer_norm_1.weight torch.Size([512])
decoder.transformer_layers.2.layer_norm_1.bias torch.Size([512])
decoder.transformer_layers.2.layer_norm_2.weight torch.Size([512])
decoder.transformer_layers.2.layer_norm_2.bias torch.Size([512])
decoder.transformer_layers.3.mask torch.Size([1, 5000, 5000])
decoder.transformer_layers.3.self_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.3.self_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.3.self_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.3.self_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.3.self_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.3.self_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.3.self_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.3.self_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.3.context_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.3.context_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.3.context_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.3.context_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.3.context_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.3.context_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.3.context_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.3.context_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.3.feed_forward.w_1.weight torch.Size([2048, 512])
decoder.transformer_layers.3.feed_forward.w_1.bias torch.Size([2048])
decoder.transformer_layers.3.feed_forward.w_2.weight torch.Size([512, 2048])
decoder.transformer_layers.3.feed_forward.w_2.bias torch.Size([512])
decoder.transformer_layers.3.feed_forward.layer_norm.weight torch.Size([512])
decoder.transformer_layers.3.feed_forward.layer_norm.bias torch.Size([512])
decoder.transformer_layers.3.layer_norm_1.weight torch.Size([512])
decoder.transformer_layers.3.layer_norm_1.bias torch.Size([512])
decoder.transformer_layers.3.layer_norm_2.weight torch.Size([512])
decoder.transformer_layers.3.layer_norm_2.bias torch.Size([512])
decoder.transformer_layers.4.mask torch.Size([1, 5000, 5000])
decoder.transformer_layers.4.self_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.4.self_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.4.self_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.4.self_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.4.self_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.4.self_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.4.self_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.4.self_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.4.context_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.4.context_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.4.context_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.4.context_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.4.context_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.4.context_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.4.context_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.4.context_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.4.feed_forward.w_1.weight torch.Size([2048, 512])
decoder.transformer_layers.4.feed_forward.w_1.bias torch.Size([2048])
decoder.transformer_layers.4.feed_forward.w_2.weight torch.Size([512, 2048])
decoder.transformer_layers.4.feed_forward.w_2.bias torch.Size([512])
decoder.transformer_layers.4.feed_forward.layer_norm.weight torch.Size([512])
decoder.transformer_layers.4.feed_forward.layer_norm.bias torch.Size([512])
decoder.transformer_layers.4.layer_norm_1.weight torch.Size([512])
decoder.transformer_layers.4.layer_norm_1.bias torch.Size([512])
decoder.transformer_layers.4.layer_norm_2.weight torch.Size([512])
decoder.transformer_layers.4.layer_norm_2.bias torch.Size([512])
decoder.transformer_layers.5.mask torch.Size([1, 5000, 5000])
decoder.transformer_layers.5.self_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.5.self_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.5.self_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.5.self_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.5.self_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.5.self_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.5.self_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.5.self_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.5.context_attn.linear_keys.weight torch.Size([512, 512])
decoder.transformer_layers.5.context_attn.linear_keys.bias torch.Size([512])
decoder.transformer_layers.5.context_attn.linear_values.weight torch.Size([512, 512])
decoder.transformer_layers.5.context_attn.linear_values.bias torch.Size([512])
decoder.transformer_layers.5.context_attn.linear_query.weight torch.Size([512, 512])
decoder.transformer_layers.5.context_attn.linear_query.bias torch.Size([512])
decoder.transformer_layers.5.context_attn.final_linear.weight torch.Size([512, 512])
decoder.transformer_layers.5.context_attn.final_linear.bias torch.Size([512])
decoder.transformer_layers.5.feed_forward.w_1.weight torch.Size([2048, 512])
decoder.transformer_layers.5.feed_forward.w_1.bias torch.Size([2048])
decoder.transformer_layers.5.feed_forward.w_2.weight torch.Size([512, 2048])
decoder.transformer_layers.5.feed_forward.w_2.bias torch.Size([512])
decoder.transformer_layers.5.feed_forward.layer_norm.weight torch.Size([512])
decoder.transformer_layers.5.feed_forward.layer_norm.bias torch.Size([512])
decoder.transformer_layers.5.layer_norm_1.weight torch.Size([512])
decoder.transformer_layers.5.layer_norm_1.bias torch.Size([512])
decoder.transformer_layers.5.layer_norm_2.weight torch.Size([512])
decoder.transformer_layers.5.layer_norm_2.bias torch.Size([512])
decoder.layer_norm.weight torch.Size([512])
decoder.layer_norm.bias torch.Size([512])
generator.0.weight torch.Size([30522, 512])
generator.0.bias torch.Size([30522])