I am running into a peculiar issue where using scaled_dot_product_attention results in slower training compared to implementing the attention manually in PyTorch. The only difference between the two runs is using SDPA vs. manual attention. My complete setup is non-trivial to reproduce here, and I could not construct a reasonable MWE that exhibits the same issue. I am looking for guidance on when SDPA could/would be slower than manual attention, and how I can debug this further. Here's the relevant part of the code that switches the attention op:
import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend

# Within the attention module (self.attention_type, self.dropout, etc. are module attributes):
if self.attention_type == AttentionType.SDPA:
    # Restrict SDPA to these backends (FLASH_ATTENTION is not in the list).
    with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH, SDPBackend.CUDNN_ATTENTION]):
        attn_output = scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
            scale=1.0,
        )
elif self.attention_type == AttentionType.MANUAL:
    # Unscaled attention (matches scale=1.0 above): QK^T + mask, softmax in fp32,
    # dropout, then the weighted sum over values.
    scores = torch.matmul(query_states, key_states.transpose(3, 2))
    scores += mask
    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)
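
For reference, this is roughly the kind of isolated comparison I attempted when trying to build an MWE; the shapes, dtype, device, and mask below are illustrative placeholders, not my real configuration.

import torch
from torch.nn.functional import scaled_dot_product_attention
import torch.utils.benchmark as benchmark

# Placeholder dimensions: batch, heads, sequence length, head dim.
B, H, S, D = 8, 16, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
mask = torch.zeros(B, 1, S, S, device="cuda", dtype=torch.bfloat16)

def sdpa_attn():
    return scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=1.0)

def manual_attn():
    scores = torch.matmul(q, k.transpose(3, 2)) + mask
    weights = torch.softmax(scores.float(), dim=-1).type_as(scores)
    return torch.matmul(weights, v)

# torch.utils.benchmark handles CUDA synchronization and warmup.
for name, fn in [("sdpa", sdpa_attn), ("manual", manual_attn)]:
    timer = benchmark.Timer(stmt="fn()", globals={"fn": fn})
    print(name, timer.timeit(100))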
The entire model is compiled with torch.compile before training begins in both cases. However, I noticed that SDPA is slower only when the model is compiled; otherwise it is either on par with or slightly faster than manual attention. In all cases, I see significant memory savings from using SDPA.
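
To dig further, my current plan is to profile a few compiled training steps under both settings and compare the kernel breakdown, roughly along these lines (run_step is a placeholder for one training iteration of my compiled model):

from torch.profiler import profile, ProfilerActivity

# Collect CPU + CUDA activity over a handful of steps and compare the top kernels
# between the SDPA and manual runs.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        run_step()  # placeholder: one compiled forward/backward/optimizer step
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))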
EDIT: I tried checking for graph breaks and recompiles in both cases but did not find any obvious issues.
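
The check was along these lines (model and batch stand in for my actual module and inputs), in addition to running with TORCH_LOGS="graph_breaks,recompiles" set:

import torch

# torch._dynamo.explain reports graph breaks and their reasons for a callable + inputs.
explanation = torch._dynamo.explain(model)(batch)
print(explanation)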