Error when trying to quantize and save a DeBERTa model using PyTorch

I am trying to quantize deberta-base and save the quantized model, so I do the following:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
model = AutoModel.from_pretrained("microsoft/deberta-base", torchscript=True)

quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

input_ids = torch.tensor([[1, 12196, 16, 5853, 2723, 102, 116, 2]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]])
# DebertaModel.forward takes attention_mask as its second positional
# argument, so the example inputs are passed in that order.
traced_model = torch.jit.trace(quantized_model, (input_ids, attention_mask, token_type_ids))
torch.jit.save(traced_model, "traced_deberta.pt")

I get the following error:

RuntimeError: 
Could not export Python function call 'XSoftmax'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/usr/local/lib/python3.8/site-packages/transformers/models/deberta/modeling_deberta.py(642): forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/usr/local/lib/python3.8/site-packages/transformers/models/deberta/modeling_deberta.py(280): forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/usr/local/lib/python3.8/site-packages/transformers/models/deberta/modeling_deberta.py(347): forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/usr/local/lib/python3.8/site-packages/transformers/models/deberta/modeling_deberta.py(442): forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/usr/local/lib/python3.8/site-packages/transformers/models/deberta/modeling_deberta.py(954): forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1090): _slow_forward
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py(1102): _call_impl
/usr/local/lib/python3.8/site-packages/torch/jit/_trace.py(958): trace_module
/usr/local/lib/python3.8/site-packages/torch/jit/_trace.py(741): trace

What is the cause of this?

Transformers implements DeBERTa’s masked softmax (the XSoftmax in your traceback) as a custom autograd function. The stated reason is to save memory, though I am not entirely sure why that would be, because they don’t say what the more memory-consuming alternative would have been.
The JIT doesn’t support custom autograd functions, but XSoftmax does define a symbolic method, so ONNX export is available if that works for you.
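
If ONNX is an option, something along these lines should work for the un-quantized model; the file name, opset version, and input/output names here are placeholders I picked for the example, not anything prescribed:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
model = AutoModel.from_pretrained("microsoft/deberta-base", torchscript=True)
model.eval()

# Tokenize a dummy sentence to get example inputs for tracing.
inputs = tokenizer("What is quantization?", return_tensors="pt")

torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    "deberta.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    opset_version=13,  # assumption: a recent-enough opset for the symbolic
)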

One option could be to rewrite this bit as masking using stock PyTorch operations; that seems to work reasonably well for regular multi-head attention (see the sketch below).
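
For illustration, a minimal sketch of what such a rewrite could look like; the function name and the convention of 1 = keep in the mask are my own choices for the example, not what the library uses:

import torch

def masked_softmax(scores, mask, dim=-1):
    # mask holds 1 for positions to attend to and 0 for masked positions.
    rmask = ~mask.bool()
    # Push masked positions to the lowest representable value so that
    # softmax assigns them (nearly) zero probability...
    scores = scores.masked_fill(rmask, torch.finfo(scores.dtype).min)
    probs = torch.softmax(scores, dim=dim)
    # ...and then zero them out exactly.
    return probs.masked_fill(rmask, 0.0)

Because this only uses standard differentiable ops, autograd handles the backward pass and the JIT can trace it without complaining about Python function calls.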

Best regards

Thomas