When I use torch-tensorrt to speed up inference, the nn.MultiheadAttention module does not run in fp16 mode. Here is the code:
import torch
import torch.nn as nn

class MyModel(torch.nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn = nn.MultiheadAttention(n_heads * d_head, n_heads)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim); nn.MultiheadAttention expects (seq_len, batch, embed_dim)
        x_t = x.transpose(0, 1)
        x = self.attn(x_t, x_t, x_t, need_weights=False)[0].transpose(0, 1) + x
        return x
First, I trace the model with torch.jit.trace, and then compile it with torch_tensorrt.compile. Even though I set enabled_precisions={torch_tensorrt.dtype.half}, the multihead attention still runs in fp32. And if I run the torchtrt model under with autocast("cuda", torch.float16), I get an error.
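For reference, here is a minimal sketch of the steps I run (the input shape and the dim=320, n_heads=8, d_head=40 values are illustrative assumptions, not my exact numbers):

import torch
import torch_tensorrt

model = MyModel(dim=320, n_heads=8, d_head=40).cuda().eval()
# example input: (batch, seq_len, embed_dim) -- illustrative shape
example_input = torch.randn(2, 64, 320, device="cuda")

# trace, then compile with fp16 enabled
traced = torch.jit.trace(model, example_input)
trt_model = torch_tensorrt.compile(
    traced,
    inputs=[torch_tensorrt.Input(example_input.shape)],
    enabled_precisions={torch_tensorrt.dtype.half},
)

# the attention still runs in fp32 here; wrapping the call in autocast raises the error
with torch.autocast("cuda", dtype=torch.float16):
    out = trt_model(example_input)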