Using `flex_attention` with DistributedDataParallel (DDP)

I’ve been experimenting with the new flex_attention module and encountered an issue when trying to integrate it with DistributedDataParallel (DDP). Because flex_attention is traced into the graph as a higher-order operator, it seems to conflict with Dynamo’s DDPOptimizer, the torch.compile pass that splits the graph along DDP’s gradient buckets.

Below is a minimal example of my current setup:

import os
import time
import math

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.attention.flex_attention import flex_attention

class Model(torch.nn.Module):
    """Multi-head self-attention with an ALiBi-style bias, computed by ``flex_attention``.

    Args:
        S: Sequence length the model was configured with. Kept as an attribute
            for compatibility; ``forward`` reads the actual length from its input.
        H: Number of attention heads.
        D: Per-head feature size; inputs carry ``H * D`` features.
    """

    def __init__(self, S, H, D):
        super().__init__()

        self.S = S
        self.H = H
        self.D = D

        # One slope per head, registered as a buffer so it follows .to(device)
        # and is included in the state dict (persistent=True).
        alibi_bias = self.generate_alibi_bias(H)
        self.register_buffer("alibi_bias", alibi_bias, persistent=True)
        self.attention = flex_attention

        # Fused Q/K projection (split in forward) and a separate V projection.
        self.project_qk = torch.nn.Linear(H * D, H * D * 2)
        self.project_v = torch.nn.Linear(H * D, H * D)

    def forward(self, hidden_states):
        """Run biased self-attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, H * D)``.

        Returns:
            The ``flex_attention`` output, shape ``(batch, H, seq_len, D)``.
        """
        query, key = self.project_qk(hidden_states).chunk(2, dim=2)
        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(self.project_v(hidden_states))

        return self.attention(query, key, value, score_mod=self.alibi_score_mod)

    def _split_heads(self, x):
        """Reshape a ``(batch, seq_len, H * D)`` projection to ``(batch, H, seq_len, D)``."""
        batch_size, seq_len, _ = x.shape
        # Split the feature axis into heads within each (batch, position) row,
        # then move heads in front of the sequence axis. Viewing the buffer as
        # (seq_len, batch, ...) instead would interleave tokens across batch
        # elements, because the memory is laid out batch-major.
        return x.view(batch_size, seq_len, self.H, self.D).transpose(1, 2)

    def generate_alibi_bias(self, num_heads):
        """Return per-head slopes ``2 ** (-8 * (i + 1) / num_heads)`` for ``i`` in ``range(num_heads)``."""
        # ``2.0 ** x`` rather than ``math.exp2`` (Python 3.11+) keeps this portable.
        slopes = [2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)]
        return torch.tensor(slopes)

    def alibi_score_mod(self, score, b, h, q_idx, kv_idx):
        """flex_attention ``score_mod``: add ``(q_idx - kv_idx) * alibi_bias[h]`` to ``score``."""
        # NOTE(review): the ALiBi paper penalises distance (bias = -slope * (q - k));
        # this adds +slope * (q - k). Confirm the sign is intended.
        bias = (q_idx - kv_idx) * self.alibi_bias[h]
        return score + bias

if __name__ == "__main__":
    # Problem size: batch, heads, sequence length, per-head dim.
    B = 64
    H = 12
    S = 512
    D = 64

    # torchrun exports these variables for every worker process.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Bind this process to its GPU before creating the NCCL process group.
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    try:
        model = Model(S, H, D).to(device)
        model = DistributedDataParallel(model, device_ids=[local_rank])
        # torch.compile returns a new compiled wrapper and leaves its argument
        # untouched; discarding the result (as before) meant running eager code.
        # NOTE(review): per the error reported below, Dynamo's DDPOptimizer
        # rejects higher-order ops such as flex_attention; its suggested escape
        # hatch is torch._dynamo.config.optimize_ddp = False.
        model = torch.compile(model)

        for i in range(100):
            # Allocate inputs directly on the GPU and outside the timed region
            # so the measurement covers the forward pass only.
            hidden_states = torch.randn(B, S, H * D, device=device)
            start = time.perf_counter()
            model(hidden_states)
            torch.cuda.synchronize()
            print(f"{i}: {time.perf_counter() - start:.4f}")
    finally:
        torch.distributed.destroy_process_group()

I run the script using the following command:

torchrun --standalone --nnodes=1 --nproc_per_node=1 flex_attention_test.py

However, I encounter the following error:

[rank0]:   File "/home/colibri/mambaforge/envs/pytorch2_5/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1457, in _call_user_compiler
[rank0]:     raise BackendCompilerFailed(self.compiler_fn, e) from e
[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
[rank0]: NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph. Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.

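Disabling the DDP optimizer (torch._dynamo.config.optimize_ddp = False) resolves the error but causes significant performance degradation, since the whole Dynamo graph then becomes a single gradient bucket.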

I’m seeking guidance on whether there’s a proper way to use flex_attention or similar higher-order operations in conjunction with DDP without sacrificing performance. Any advice or insights would be greatly appreciated.

1 Like

I'm having the exact same issue; I could only get it working with the DDP optimizer disabled.