Torch-TensorRT: multihead attention cannot run in fp16 mode

When I use Torch-TensorRT to speed up inference, the multihead attention module cannot run in fp16 mode. Here is the code:

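```python
import torch
from torch import nn

class MyModel(torch.nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()  # must run before registering submodules
        self.attn = nn.MultiheadAttention(n_heads * d_head, n_heads)

    def forward(self, x):
        # nn.MultiheadAttention expects (seq, batch, embed) by default
        h = x.transpose(0, 1)
        # self-attention, back to (batch, seq, embed), plus residual in the same layout
        x = self.attn(h, h, h, need_weights=False)[0].transpose(0, 1) + x
        return x
```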

First, I trace the model with `torch.jit.trace`, then compile it with `torch_tensorrt.compile`. Even though I set `enabled_precisions={torch_tensorrt.dtype.half}`, the multihead attention still runs in fp32. And if I run the Torch-TensorRT model inside `with autocast("cuda", torch.float16):`, I get an error.
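Below is a minimal sketch of the trace-then-compile workflow described above. It assumes a CUDA device and an installed `torch_tensorrt`, and it applies one commonly suggested approach (an assumption, not verified for this exact model): cast the module and the example input to half *before* tracing and declare the input as `torch.half`, so that any op that is not converted by TensorRT and instead falls back to PyTorch (possibly including the multihead attention) also runs in fp16 rather than fp32. The helper `compile_fp16`, the example shape `(1, 77, 512)`, and the reduced `MyModel` signature are hypothetical.

```python
import torch
from torch import nn


class MyModel(nn.Module):
    """Hypothetical, reduced stand-in for the model in the question."""

    def __init__(self, n_heads=8, d_head=64):
        super().__init__()
        self.attn = nn.MultiheadAttention(n_heads * d_head, n_heads)

    def forward(self, x):
        h = x.transpose(0, 1)  # (batch, seq, embed) -> (seq, batch, embed)
        # attention output back to (batch, seq, embed), plus residual
        return self.attn(h, h, h, need_weights=False)[0].transpose(0, 1) + x


def compile_fp16(model, example_shape=(1, 77, 512)):
    """Hypothetical helper: trace and compile with weights and inputs in fp16."""
    import torch_tensorrt  # imported lazily so the model above works without it

    # Assumption: half-precision weights/inputs keep PyTorch-fallback ops in fp16 too
    model = model.eval().cuda().half()
    example = torch.randn(*example_shape, device="cuda", dtype=torch.half)
    traced = torch.jit.trace(model, example)
    trt_model = torch_tensorrt.compile(
        traced,
        inputs=[torch_tensorrt.Input(shape=example_shape, dtype=torch.half)],
        enabled_precisions={torch_tensorrt.dtype.half},
    )
    return trt_model, example
```

Calling `trt_model, example = compile_fp16(MyModel())` and then checking `trt_model(example).dtype` under `torch.no_grad()` is a quick way to see whether the compiled module really returns fp16, without wrapping the call in `autocast`.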