You are explicitly using the GPU via:
with torch.cuda.device
so did you check if flash_attn supports CPU-only workloads?
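
If it helps, here is a rough sketch of how you could guard the choice at runtime instead of assuming a GPU is there. It assumes, as far as I know, that flash_attn's kernels need a GPU and will not run on a CPU-only machine. The names `pick_attention_backend`, `detect_backend`, and the "fallback" label are hypothetical and only for illustration; `torch.cuda.is_available()` and `importlib.util.find_spec` are real calls.

```python
import importlib.util


def pick_attention_backend(cuda_available: bool, flash_attn_installed: bool) -> str:
    """Choose "flash_attn" only when a CUDA device and the package are both present.

    Assumption: flash_attn cannot run on CPU, so anything else gets "fallback"
    (a hypothetical label for whatever standard attention path you already have).
    """
    if cuda_available and flash_attn_installed:
        return "flash_attn"
    return "fallback"


def detect_backend() -> str:
    """Probe the current environment without failing if torch or flash_attn is missing."""
    has_flash = importlib.util.find_spec("flash_attn") is not None
    has_torch = importlib.util.find_spec("torch") is not None
    cuda = False
    if has_torch:
        import torch  # imported lazily so this file still loads without torch

        cuda = torch.cuda.is_available()
    return pick_attention_backend(cuda, has_flash)


if __name__ == "__main__":
    print(detect_backend())
```

On a CPU-only box this prints "fallback", which tells you the flash_attn path should not be taken there.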