Hello, I am trying to implement my own neural machine translation model with Flash Attention (using scaled_dot_product_attention from torch.nn.functional).
I am running into two problems:
- When I use dtype=torch.float16, I get the following error:
RuntimeError: "baddbmm_with_gemm" not implemented for 'Half'
- When I set device = 'cuda', I get this error:
RuntimeError: No available kernel. Aborting execution.
For the second point, I found a suggested solution (upgrading PyTorch to the version mentioned in that post), but it didn't help.
Here is the piece of code that raises the error:
if self.flash:
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
        y = F.scaled_dot_product_attention(Q, K, V,
                                           attn_mask=mask,
                                           dropout_p=self.dropout if self.training else 0,
                                           is_causal=True)
else:
    raise ImportError("PyTorch >= 2.0 is required to use Flash Attention")
I can post the full function code if necessary.
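For comparison, here is a minimal sketch of a scaled_dot_product_attention call that runs for me (tensor shapes are made up for illustration). Note that the flash kernel does not accept an explicit attn_mask, so the causal mask is requested via is_causal=True alone; passing both attn_mask and is_causal=True may itself be the reason no kernel is available.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch, heads, sequence length, head dimension
B, H, T, D = 2, 4, 8, 16
Q = torch.randn(B, H, T, D)
K = torch.randn(B, H, T, D)
V = torch.randn(B, H, T, D)

# No explicit attn_mask: is_causal=True lets the kernel build the
# causal mask internally. On CPU this falls back to the math kernel;
# on CUDA the flash kernel can be selected via sdp_kernel.
y = F.scaled_dot_product_attention(Q, K, V,
                                   attn_mask=None,
                                   dropout_p=0.0,
                                   is_causal=True)
print(y.shape)  # same shape as Q: (2, 4, 8, 16)
```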