MultiheadAttention export to ONNX fails when using torch.no_grad()

Hi!

I’m trying to export a model that contains a torch.nn.MultiheadAttention module to ONNX. However, I’m running into the following error:

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::_native_multi_head_attention' to ONNX opset version 17 is not supported.

I created this small wrapper model to replicate the issue:

import torch


class TestModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.self_attn = torch.nn.MultiheadAttention(
            embed_dim=1024,
            num_heads=8,
            dropout=0.1,  # dropout expects a float probability, not a bool
            batch_first=True,
        )

    def forward(self, hidden_states, attention_mask):
        # With need_weights=False, only the attention output is returned.
        x, _ = self.self_attn(
            query=hidden_states,
            key=hidden_states,
            value=hidden_states,
            key_padding_mask=attention_mask.bool(),
            need_weights=False,
        )
        return x

model = TestModel().eval()
batch_size = 16
q = torch.randn((batch_size, 50, 1024))
mask = torch.zeros((batch_size, 50))  # all zeros -> no positions are masked

with torch.no_grad():
    torch.onnx.export(
        model,
        (q, mask),
        "test-model.onnx",
        input_names=["query", "mask"],
        output_names=["attn_output"],  # the model returns a single output
    )
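As a sanity check that the two code paths really differ, here is a sketch that compares the traced JIT graphs with and without torch.no_grad() (I’m assuming torch.jit.trace goes through the same dispatch as the ONNX exporter; the expected True/False results are my guess based on the error above):

# Sketch: check whether the fused op appears in the traced graph.
# Assumption: torch.jit.trace takes the same dispatch path as the ONNX exporter.
with torch.no_grad():
    graph_no_grad = torch.jit.trace(model, (q, mask)).graph
graph_with_grad = torch.jit.trace(model, (q, mask)).graph

print("aten::_native_multi_head_attention" in str(graph_no_grad))    # expected: True
print("aten::_native_multi_head_attention" in str(graph_with_grad))  # expected: False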

According to this GitHub issue, it looks like ONNX export of this operator is not implemented yet.

But the thing I really don’t get is that this error only happens when I use the torch.no_grad() context manager. Without it, the export succeeds. Why?
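My current guess (please correct me if I’m wrong): with autograd disabled, nn.MultiheadAttention takes its fused fast path and dispatches to aten::_native_multi_head_attention, which has no ONNX symbolic; with autograd enabled, it falls back to the slow path that decomposes into standard, exportable ops. If that’s right, would explicitly disabling the fast path be a valid workaround? A sketch, assuming torch.backends.mha.set_fastpath_enabled is available in this build (I’m not sure it is in 2.2.0a0):

import torch.backends.mha

# Sketch of a possible workaround: force the slow, decomposed attention path
# so the exporter never sees aten::_native_multi_head_attention.
# Assumption: torch.backends.mha.set_fastpath_enabled exists in this build.
torch.backends.mha.set_fastpath_enabled(False)
with torch.no_grad():
    torch.onnx.export(
        model,
        (q, mask),
        "test-model.onnx",
        input_names=["query", "mask"],
        output_names=["attn_output"],
    )
torch.backends.mha.set_fastpath_enabled(True)  # restore the default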

I need this context manager because otherwise my real model raises a CUDA OOM error when exporting to ONNX on GPU (more precisely, during JIT graph creation in the torch.onnx.utils._create_jit_graph() function).
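As an alternative, I could export the real model on CPU with autograd enabled, which (if my guess above is right) avoids both the fused op and the CUDA OOM, at the cost of slower tracing. A sketch with the wrapper model:

# Sketch: export on CPU with grad enabled, to avoid both the fused-op error
# and the CUDA OOM. (q and mask are already on CPU here; .cpu() is for the
# general case where the real model and inputs live on GPU.)
model_cpu = model.cpu().eval()
torch.onnx.export(
    model_cpu,
    (q.cpu(), mask.cpu()),
    "test-model.onnx",
    input_names=["query", "mask"],
    output_names=["attn_output"],
)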

Environment:

I’m using the NVIDIA PyTorch NGC container nvcr.io/nvidia/pytorch:24.01-py3, with torch.__version__ = '2.2.0a0+81ea7a4'.