torch.ao.nn.quantizable.modules.activation.MultiheadAttention not loading the pre-trained model weights correctly

I have pre-trained model weights saved as a .pt file and I am using the model for quantization. I replaced torch.nn.MultiheadAttention with torch.ao.nn.quantizable.modules.activation.MultiheadAttention, but the loss I am getting is about 10 times higher. Everything else is kept the same.

I imported the module like this:

from torch.ao.nn.quantizable.modules.activation import MultiheadAttention

Then I changed the code like this:

        #self.attn = nn.MultiheadAttention(
        self.attn = MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=attn_dropout,
            add_bias_kv=add_bias_kv,
        )
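One thing I am not sure about is whether the quantizable module should be constructed directly like this at all, or whether it is meant to be created from an already-loaded float nn.MultiheadAttention via its from_float classmethod, which (as far as I can tell) re-maps the fused in_proj_weight / in_proj_bias onto the quantizable module's separate linear_Q / linear_K / linear_V submodules. A rough sketch of what I mean; the sizes below are placeholders rather than my real hyperparameters, and I believe from_float expects a qconfig to be set on the float module first:

import torch
import torch.ao.quantization
import torch.nn as nn
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention

# Placeholder hyperparameters for illustration only.
embed_dim, num_heads, attn_dropout, add_bias_kv = 128, 8, 0.1, True

# 1) Build the original float module and load the pre-trained weights into it.
float_attn = nn.MultiheadAttention(
    embed_dim, num_heads, dropout=attn_dropout, add_bias_kv=add_bias_kv
)
# float_attn.load_state_dict(...)  # attention weights from the original checkpoint

# 2) Convert to the quantizable module; from_float copies the weights over,
#    splitting in_proj_weight / in_proj_bias into linear_Q / linear_K / linear_V.
float_attn.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
quantizable_attn = MultiheadAttention.from_float(float_attn)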

Then I load the pre-trained weights like this:

pretrained_dict = torch.load("/part-vol-2/weaver-core/particle_transformer/models/ParT_full.pt")

# Keep only the checkpoint entries whose names exist in the quantizable model
model_dict = pre_trained_model_quant.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
pre_trained_model_quant.load_state_dict(model_dict)
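To check how much of the checkpoint actually maps onto the quantizable model, I can compare the key sets (a quick diagnostic sketch using the variables from the snippet above); any parameter that never receives a checkpoint value stays at its random initialization, which could explain the much higher loss:

# Diagnostic sketch: list the model parameters the checkpoint never provides.
# Note: nn.MultiheadAttention stores a fused in_proj_weight / in_proj_bias,
# while the quantizable module keeps separate linear_Q / linear_K / linear_V
# layers, so the attention entries may not line up between the two layouts.
model_keys = set(pre_trained_model_quant.state_dict().keys())
checkpoint_keys = set(pretrained_dict.keys())
print("parameters never loaded from the checkpoint:")
for name in sorted(model_keys - checkpoint_keys):
    print("  ", name)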

Is there anything else that needs to be done for configuration? Why am I seeing this discrepancy?

I don’t think we have tested this for training. You can check out how we use this module in test/quantization/core/test_quantized_op.py in the pytorch/pytorch repository on GitHub (commit 53f8a5fde268b51260f136d3b06fc1e86e6912a7).