Is it possible for torch SDPA to be slower than manual attention?

I am running into a peculiar issue where using scaled_dot_product_attention results in slower training compared to implementing the attention manually in PyTorch. The only difference between the two runs is using SDPA vs manual attention. My complete setup is non-trivial to reproduce here, and I could not construct a reasonable MWE that shows the same issue. I am looking for guidance on when SDPA could or would be slower than manual attention, and how I can debug this further. Here’s the relevant part of the code that switches the attention op.

if self.attention_type == AttentionType.SDPA:
    # Allow only the efficient, math, and cuDNN backends (FLASH_ATTENTION is excluded).
    with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH, SDPBackend.CUDNN_ATTENTION]):
        attn_output = scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
            scale=1.0,
        )
elif self.attention_type == AttentionType.MANUAL:
    # Unscaled scores to match scale=1.0 in the SDPA branch; the mask is additive.
    scores = torch.matmul(query_states, key_states.transpose(3, 2))
    scores += mask
    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

The entire model is torch.compiled before training begins in both cases. However, I noticed that SDPA is slower only when the models are compiled; otherwise it is on par with, or slightly faster than, manual attention. In all cases, I see significant memory savings from using SDPA.
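
For reference, this is roughly how I isolate and time the two code paths outside the training loop, with explicit synchronization (a minimal sketch with made-up shapes and a dummy all-zeros mask; it does not reproduce the end-to-end slowdown):

import torch
import torch.nn.functional as F

def time_fn(fn, *args, warmup=10, iters=50):
    # Warm up so compilation/autotuning is not included in the measurement.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()  # wait for all kernels before reading the timer
    return start.elapsed_time(end) / iters  # ms per iteration

def sdpa_attn(q, k, v, mask):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=1.0)

def manual_attn(q, k, v, mask):
    scores = torch.matmul(q, k.transpose(3, 2)) + mask
    weights = F.softmax(scores.float(), dim=-1).type_as(scores)
    return torch.matmul(weights, v)

B, H, S, D = 4, 16, 2048, 64  # made-up shapes, not my real config
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
mask = torch.zeros(B, 1, S, S, device="cuda", dtype=torch.bfloat16)

for name, fn in [("sdpa", sdpa_attn), ("manual", manual_attn)]:
    print(name, time_fn(torch.compile(fn), q, k, v, mask), "ms/iter")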

EDIT: I tried checking for graph breaks and recompiles in both cases but did not find any obvious issues.
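
Concretely, I checked along these lines (a sketch; model and batch stand in for my actual module and inputs):

import torch

# Option 1: ask dynamo to explain what it captured for one forward pass.
explanation = torch._dynamo.explain(model)(batch)
print(explanation.graph_count, explanation.graph_break_count)
print(explanation.break_reasons)

# Option 2: run training with logging enabled, e.g.
#   TORCH_LOGS="graph_breaks,recompiles" python train.py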

Could you share your profiling code to reproduce the issue, as synchronizations are often missing and the profiles are thus invalid?

@ptrblck Unfortunately, the codebase is too large to share, and I could not reproduce the issue in an MWE. To answer your question: indeed, syncs are missing in my case. However, here I am talking about the end-to-end training runtime, which increases when SDPA is used instead of manual attention.

I am looking more for potential directions of investigation, if you have any off the top of your head.

For example, could the sparsity of the attention mask somehow result in compiled manual attention being faster than compiled SDPA? I am seeing this in the case where I have document-masking style masks.
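
To be concrete, by document-masking style masks I mean an additive block-diagonal mask built from per-token document ids, roughly like this (illustrative helper, not my exact code):

import torch

def document_mask(doc_ids: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # doc_ids: (batch, seq_len) integer document id per token.
    # Tokens may only attend to tokens from the same document.
    same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)  # (B, S, S) bool
    mask = torch.zeros(same_doc.shape, dtype=dtype, device=doc_ids.device)
    mask.masked_fill_(~same_doc, torch.finfo(dtype).min)
    return mask.unsqueeze(1)  # (B, 1, S, S), broadcast over heads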

OK, in this case I think it would make sense to profile your end-to-end code to see what exactly changes and where the slowdowns occur. For example, you could use Nsight Systems or the native PyTorch profiler to create a timeline showing the GPU kernel execution times, their launches, idle time, etc.
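
For the native profiler, something along these lines gives you a per-step timeline you can open in TensorBoard or chrome://tracing (a minimal sketch; loader and train_step stand in for your training loop, and the schedule should be adapted to it):

import torch
from torch.profiler import profile, schedule, ProfilerActivity

sched = schedule(wait=1, warmup=2, active=3, repeat=1)
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=sched,
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler_logs"),
) as prof:
    for step, batch in enumerate(loader):
        train_step(batch)
        prof.step()  # advance the profiler schedule every iteration
        if step >= 6:
            break
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))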

Thank you @ptrblck. I will try profiling. Meanwhile, I was able to get an MWE up and running and opened an issue on GitHub: SDPA (`EFFICIENT_ATTENTION`) slower than manual attention · Issue #149857 · pytorch/pytorch