I am trying to quantize deberta-base and save the quantized model, so I do the following:
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
# torchscript=True makes the model return tuples instead of ModelOutput objects, as tracing expects
model = AutoModel.from_pretrained("microsoft/deberta-base", torchscript=True)
# Dynamic int8 quantization of every nn.Linear layer
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
input_ids = torch.tensor([[1, 12196, 16, 5853, 2723, 102, 116, 2]])
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
# DebertaModel.forward takes (input_ids, attention_mask, token_type_ids, ...) positionally
traced_model = torch.jit.trace(quantized_model, (input_ids, attention_mask, token_type_ids))
torch.jit.save(traced_model, "traced_deberta.pt")
I get the following error:
RuntimeError:
Could not export Python function call 'XSoftmax'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/usr/local/lib/python3.8/site-packages/transformers/models/deberta/modeling_deberta.py(642): forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/usr/local/lib/python3.8/site-packages/transformers/models/deberta/modeling_deberta.py(280): forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/usr/local/lib/python3.8/site-packages/transformers/models/deberta/modeling_deberta.py(347): forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/usr/local/lib/python3.8/site-packages/transformers/models/deberta/modeling_deberta.py(442): forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/usr/local/lib/python3.8/site-packages/transformers/models/deberta/modeling_deberta.py(954): forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/usr/local/lib/python3.8/site-packages/torch/jit/_trace.py(958): trace_module
/usr/local/lib/python3.8/site-packages/torch/jit/_trace.py(741): trace
What is the cause of this?
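A minimal sketch that appears to reproduce the same failure without DeBERTa or quantization, under the assumption (worth checking in the installed modeling_deberta.py) that XSoftmax is implemented as a custom torch.autograd.Function. The names MaskedSoftmaxFn, WithAutogradFunction, WithPlainOps and try_trace_and_save are hypothetical. torch.jit.trace seems to record a custom autograd.Function as an opaque Python call, and serializing the traced module then fails with the same "Could not export Python function call" message, while the same masked softmax written with ordinary tensor ops traces and saves cleanly.

```python
import io

import torch


class MaskedSoftmaxFn(torch.autograd.Function):
    """Hypothetical stand-in for DeBERTa's XSoftmax: a custom autograd.Function."""

    @staticmethod
    def forward(ctx, scores, mask):
        # Fill masked positions with the dtype minimum, softmax, then zero them out.
        rmask = ~mask.to(torch.bool)
        out = scores.masked_fill(rmask, torch.finfo(scores.dtype).min).softmax(-1)
        out = out.masked_fill(rmask, 0.0)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (out,) = ctx.saved_tensors
        # Standard softmax gradient; the mask receives no gradient.
        grad_in = out * (grad_out - (grad_out * out).sum(-1, keepdim=True))
        return grad_in, None


class WithAutogradFunction(torch.nn.Module):
    def forward(self, scores, mask):
        return MaskedSoftmaxFn.apply(scores, mask)


class WithPlainOps(torch.nn.Module):
    # Same computation using only built-in tensor ops, which the tracer can record.
    def forward(self, scores, mask):
        rmask = ~mask.to(torch.bool)
        out = scores.masked_fill(rmask, torch.finfo(scores.dtype).min).softmax(-1)
        return out.masked_fill(rmask, 0.0)


def try_trace_and_save(module, example_inputs):
    """Return None if tracing and saving succeed, else the RuntimeError message."""
    try:
        traced = torch.jit.trace(module, example_inputs)
        torch.jit.save(traced, io.BytesIO())  # save to memory, no file needed
        return None
    except RuntimeError as exc:
        return str(exc)


scores = torch.randn(1, 4)
mask = torch.tensor([[1, 1, 1, 0]])
print(try_trace_and_save(WithAutogradFunction(), (scores, mask)))
print(try_trace_and_save(WithPlainOps(), (scores, mask)))
```

If this reading is right, tracing and saving the unquantized model should fail the same way, which would be a quick way to confirm that quantize_dynamic is not the trigger.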