I have pre-trained model weights saved as a .pt file and I am preparing the model for quantization. I replaced torch.nn.MultiheadAttention with torch.ao.nn.quantizable.modules.activation.MultiheadAttention, but the loss I get is about 10 times higher. Everything else is kept the same.
I imported the module like this:
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
Then I changed the code like this:
# self.attn = nn.MultiheadAttention(
self.attn = MultiheadAttention(
    embed_dim,
    num_heads,
    dropout=attn_dropout,
    add_bias_kv=add_bias_kv,
)
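For context, this is roughly how the block looks after the swap (a minimal, self-contained sketch; the AttentionBlock wrapper and the default argument values are placeholders of mine, not the actual model code):

import torch.nn as nn
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention

class AttentionBlock(nn.Module):  # hypothetical wrapper, just to illustrate the swap
    def __init__(self, embed_dim, num_heads, attn_dropout=0.1, add_bias_kv=False):
        super().__init__()
        # quantizable drop-in replacement, same constructor arguments as nn.MultiheadAttention
        self.attn = MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=attn_dropout,
            add_bias_kv=add_bias_kv,
        )

    def forward(self, x, key_padding_mask=None):
        # same call signature as the float nn.MultiheadAttention
        out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        return out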
Then I load the pretrained weights like this:
pretrained_dict = torch.load("/part-vol-2/weaver-core/particle_transformer/models/ParT_full.pt")
# Merge the pretrained weights into the model's current state dict
model_dict = pre_trained_model_quant.state_dict()
model_dict.update(pretrained_dict)
pre_trained_model_quant.load_state_dict(model_dict)
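For reference, a more defensive version of the same loading step would look like this (just a sketch; pre_trained_model_quant and the checkpoint path are the ones from above, and the key/shape filtering is my assumption of what "load only the parameters that exist in the model" should mean):

import torch

checkpoint_path = "/part-vol-2/weaver-core/particle_transformer/models/ParT_full.pt"
pretrained_dict = torch.load(checkpoint_path, map_location="cpu")

model_dict = pre_trained_model_quant.state_dict()

# keep only checkpoint tensors whose names and shapes also exist in the quantizable model
filtered = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
skipped = sorted(k for k in pretrained_dict if k not in filtered)

model_dict.update(filtered)
pre_trained_model_quant.load_state_dict(model_dict)

print(f"loaded {len(filtered)} tensors, skipped {len(skipped)}")
print("skipped keys:", skipped)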
Is there anything else that needs to be done in the configuration? Why is the loss so much higher?