Using `flex_attention` with DistributedDataParallel (DDP)

I’ve been experimenting with the new flex_attention module and ran into an issue when trying to use it with DistributedDataParallel (DDP). Because flex_attention is implemented as a higher-order op, it appears to conflict with Dynamo’s DDP optimizer (`DDPOptimizer`), the torch.compile pass that splits the compiled graph into DDP buckets.

Below is a minimal example of my current setup:

import os
import time
import math

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.attention.flex_attention import flex_attention

class Model(torch.nn.Module):
    def __init__(self, S, H, D):

        self.S = S
        self.H = H
        self.D = D

        alibi_bias = self.generate_alibi_bias(H)
        self.register_buffer("alibi_bias", alibi_bias, persistent=True)
        self.attention = flex_attention

        self.project_qk = torch.nn.Linear(H * D, H * D * 2)
        self.project_v = torch.nn.Linear(H * D, H * D)

    def forward(self, hidden_states):
        batch_size, _, _ = hidden_states.size()

        query, key = self.project_qk(hidden_states).chunk(2, dim=2)
        query = query.view(self.S, batch_size, self.H, self.D)
        query = query.permute(1, 2, 0, 3)

        key = key.view(self.S, batch_size, self.H, self.D)
        key = key.permute(1, 2, 0, 3)

        value = self.project_v(hidden_states)
        value = value.view(self.S, batch_size, self.H, self.D)
        value = value.permute(1, 2, 0, 3)

        return self.attention(query, key, value, score_mod=self.alibi_score_mod)

    def generate_alibi_bias(self, num_heads):
        alibi_bias = [math.exp2(-((i + 1) * 8.0) / num_heads) for i in range(num_heads)]
        return torch.tensor(alibi_bias)

    def alibi_score_mod(self, score, b, h, q_idx, kv_idx):
        bias = (q_idx - kv_idx) * self.alibi_bias[h]
        return score + bias

if __name__ == "__main__":

    B = 64
    H = 12
    S = 512
    D = 64

    # torchrun exports these variables for every worker process.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Bind this process to its GPU before creating the NCCL process group.
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    device = torch.device("cuda", local_rank)

    # DDP with device_ids requires the module's parameters to already be on
    # that device, so move the model before wrapping it.
    model = Model(S, H, D).to(device)
    model = DistributedDataParallel(model, device_ids=[local_rank])
    # NOTE(review): the reported DDPOptimizer error is raised from
    # torch._dynamo, so the failing run presumably also wrapped this model
    # with torch.compile, which is not shown here.

    try:
        for i in range(100):
            start = time.perf_counter()
            # Allocate directly on the GPU instead of creating on CPU and copying.
            hidden_states = torch.randn(B, S, H * D, device=device)
            model(hidden_states)
            # CUDA kernels run asynchronously; wait for them so the timing
            # reflects the forward pass, not just the kernel launches.
            torch.cuda.synchronize(device)
            print(f"{i}: {time.perf_counter() - start:.4f}")
    finally:
        torch.distributed.destroy_process_group()

I run the script using the following command:

torchrun --standalone --nnodes=1 --nproc_per_node=1 <script.py>

However, I encounter the following error:

[rank0]:   File "/home/colibri/mambaforge/envs/pytorch2_5/lib/python3.11/site-packages/torch/_dynamo/", line 1457, in _call_user_compiler
[rank0]:     raise BackendCompilerFailed(self.compiler_fn, e) from e
[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
[rank0]: NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph. Please refer to this issue -

Disabling the DDP optimizer (`torch._dynamo.config.optimize_ddp = False`) resolves the error but significantly degrades performance.

I’m seeking guidance on whether there’s a proper way to use flex_attention or similar higher-order operations in conjunction with DDP without sacrificing performance. Any advice or insights would be greatly appreciated.


I'm also having exactly the same issue; I could only get it working with the DDP optimizer disabled.


I'm also having exactly the same issue.