Wrong tensor dtype when using FlashAttention 1.0.9

Hi everyone,

I’m working on a project for my lectures, and I want to compare the speed of FlashAttention 1.0.9 against the attention layers we implemented in class. Since the course is focused on ASR, we’ve built a Transformer that recognizes the digits spoken in an audio clip and then sums them. However, when I first tried using FlashAttention in my project, an assertion on the tensor’s dtype failed, because FlashAttention expects half-precision (float16) inputs and my tensors were float32. I then set the default dtype to torch.float16, but the log-mel spectrogram class I’m using complains about that dtype, and now I feel I’ve hit a wall.
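For reference, here is a minimal sketch of the alternative I’m considering, written with plain PyTorch: keep the model and the log-mel front end in float32 and cast only q, k and v to float16 around the attention call. This is my own assumption of how to structure it, not FlashAttention’s API. `half_precision_attention` and `reference_attention` are hypothetical helper names, and `reference_attention` is only a CPU-friendly stand-in for the real FlashAttention call, whose argument layout may differ.

```python
import torch


def half_precision_attention(q, k, v, attn_kernel):
    """Run attn_kernel on float16 copies of q, k, v; return the result in q's dtype.

    attn_kernel is a placeholder (assumption) for whatever FlashAttention call
    the model uses; here it is any callable taking (q, k, v) and returning a tensor.
    """
    orig_dtype = q.dtype
    # Cast only the attention inputs; the feature front end and the rest of
    # the model keep their original (e.g. float32) dtype.
    q16, k16, v16 = (t.to(torch.float16) for t in (q, k, v))
    out = attn_kernel(q16, k16, v16)
    # Cast back so downstream layers see the dtype they expect.
    return out.to(orig_dtype)


def reference_attention(q, k, v):
    """Plain softmax(q k^T / sqrt(d)) v, a stand-in for the FlashAttention kernel.

    Layout assumed here: (batch, heads, seq_len, head_dim). Computed in float32
    internally (half-precision matmul may be unavailable on older CPU builds),
    then returned in the input dtype, as a float16 kernel would.
    """
    scale = q.shape[-1] ** -0.5
    scores = (q.float() @ k.float().transpose(-2, -1)) * scale
    return (scores.softmax(dim=-1) @ v.float()).to(q.dtype)


# Usage: float32 activations in, float32 out; only the kernel sees float16.
q = torch.randn(2, 4, 10, 16)  # (batch, heads, seq_len, head_dim)
out = half_precision_attention(q, q, q, reference_attention)
print(out.dtype, tuple(out.shape))  # torch.float32 (2, 4, 10, 16)
```

The extra casts here are elementwise copies made once per attention call rather than per batch at the input, so the spectrogram code never has to see float16.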

Is there another way to use FlashAttention 1.0.9 without hurting the network’s performance? I suspect that casting each batch to float32 and then back to float16 isn’t the most efficient approach.