MultiHeadAttention export to ONNX fails when using torch.no_grad()

Hi!

I’m trying to export a model that contains a torch.nn.MultiheadAttention module to ONNX. However, I’m running into the following error:

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::_native_multi_head_attention' to ONNX opset version 17 is not supported.

I created this small wrapper model to replicate the issue:

import torch


class TestModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.self_attn = torch.nn.MultiheadAttention(
            embed_dim=1024,
            num_heads=8,
            dropout=0.1,  # dropout expects a float probability, not a bool
            batch_first=True,
        )

    def forward(self, hidden_states, attention_mask):
        # Self-attention over hidden_states with a boolean key padding mask
        x, _ = self.self_attn(
            query=hidden_states,
            key=hidden_states,
            value=hidden_states,
            key_padding_mask=attention_mask.bool(),
            need_weights=False,
        )
        return x

model = TestModel().eval()
batch_size = 16
q = torch.randn((batch_size, 50, 1024))
mask = torch.zeros((batch_size, 50))

with torch.no_grad():
    torch.onnx.export(
        model,
        (q, mask),
        "test-model.onnx",
        input_names=["query", "mask"],
        output_names=["attn_output"],  # forward returns a single tensor
    )

According to this GitHub issue, it looks like ONNX export for the MHA operator is not implemented yet.

But the thing I really don’t get is that this issue only happens when I use the torch.no_grad() context manager. When I don’t use it, the export succeeds. Why?

I need to use this context manager because otherwise my real model raises a CUDA OOM error when exporting to ONNX on GPU (more precisely, during JIT graph creation in the torch.onnx.utils._create_jit_graph() function).

Environment:

I’m using the NVIDIA PyTorch NGC container nvcr.io/nvidia/pytorch:24.01-py3 with torch.__version__ = '2.2.0a0+81ea7a4'.

I ran into the same problem. My current assumption is that when no gradient is needed and certain other conditions are met, PyTorch uses the “fast path” implementation, which emits the aten::_native_multi_head_attention node, and that node is not exportable.
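
To check that assumption, here is a small sketch (my own, not from the linked issue): trace the wrapper model with torch.jit.trace inside and outside of torch.no_grad() and look for the aten::_native_multi_head_attention node in the resulting graph.

import torch

# Assumes the TestModel wrapper from above; trace with and without no_grad
# and check whether the fast-path op shows up in the traced graph.
model = TestModel().eval()
q = torch.randn((2, 50, 1024))
mask = torch.zeros((2, 50))

with torch.no_grad():
    traced_no_grad = torch.jit.trace(model, (q, mask))
traced_with_grad = torch.jit.trace(model, (q, mask))

# Expected: True under no_grad (fast path taken), False otherwise
print("_native_multi_head_attention" in str(traced_no_grad.graph))
print("_native_multi_head_attention" in str(traced_with_grad.graph))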

My attempts to force PyTorch not to use it have not been successful so far… I hope this is fixed in the upcoming 2.3 release; if not, I’ll need to keep looking for solutions. For now we’re sticking with PyTorch 1.13, where the export still works.

You might be able to disable the “fast path” by wrapping your forward pass in the torch.backends.cuda.sdp_kernel context as described here, which lets you disable all of the optimized kernels.
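
In case it helps, a rough sketch of what that could look like (untested; as far as I know sdp_kernel mainly controls which scaled_dot_product_attention backend is used, so I’m not certain it also prevents the MultiheadAttention fast path):

import torch

# Untested sketch: disable the flash / memory-efficient SDPA kernels around
# the export and keep only the plain math implementation available.
with torch.no_grad(), torch.backends.cuda.sdp_kernel(
    enable_flash=False,
    enable_math=True,  # keep the unoptimized math implementation enabled
    enable_mem_efficient=False,
):
    torch.onnx.export(
        model,
        (q, mask),
        "test-model.onnx",
        input_names=["query", "mask"],
        output_names=["attn_output"],
    )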