Hi!
I’m trying to export a model that contains a torch.nn.MultiheadAttention module to ONNX. However, I’m running into the following error:
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::_native_multi_head_attention' to ONNX opset version 17 is not supported.
I created this small wrapper model to replicate the issue:
import torch


class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = torch.nn.MultiheadAttention(
            embed_dim=1024,
            num_heads=8,
            dropout=0.1,  # dropout expects a float probability, not a bool
            batch_first=True,
        )

    def forward(self, hidden_states, attention_mask):
        # key_padding_mask marks padded positions with True; with
        # need_weights=False the second return value is None
        x, _ = self.self_attn(
            query=hidden_states,
            key=hidden_states,
            value=hidden_states,
            key_padding_mask=attention_mask.bool(),
            need_weights=False,
        )
        return x
model = TestModel().eval()

batch_size = 16
q = torch.randn((batch_size, 50, 1024))
mask = torch.zeros((batch_size, 50))  # all zeros -> no padded positions

with torch.no_grad():
    torch.onnx.export(
        model,
        (q, mask),
        "test-model.onnx",
        input_names=["query", "mask"],
        output_names=["attn_output"],  # the model returns a single tensor
    )
According to this GitHub issue, it looks like the ONNX export of aten::_native_multi_head_attention is not implemented yet.
But the thing I really don’t get is that this error only occurs when I use the torch.no_grad() context manager. Without it, the export succeeds. Why?
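For reference, here is the exact same call outside the context manager; this version completes without errors for me (same model and inputs as in the snippet above):

# Identical export call, just not wrapped in torch.no_grad(); this succeeds
torch.onnx.export(
    model,
    (q, mask),
    "test-model.onnx",
    input_names=["query", "mask"],
    output_names=["attn_output"],
)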
I need to use this context manager because otherwise my real model raises a CUDA OOM error when exporting to ONNX on GPU (more precisely, during JIT graph creation in the torch.onnx.utils._create_jit_graph() function).
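For context, by “exporting on GPU” I mean roughly the following (a minimal sketch based on the toy model above; my real model is much larger, which is what triggers the OOM):

model = TestModel().eval().cuda()
q = q.cuda()
mask = mask.cuda()

# Without torch.no_grad(), this is where the CUDA OOM shows up for my real model
with torch.no_grad():
    torch.onnx.export(
        model,
        (q, mask),
        "test-model.onnx",
        input_names=["query", "mask"],
        output_names=["attn_output"],
    )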
Environment:
I’m using the NVIDIA PyTorch NGC container nvcr.io/nvidia/pytorch:24.01-py3 with torch.__version__ = '2.2.0a0+81ea7a4'.