RuntimeError: No available kernel. Flash Attention Appears Not to Work with Pipeline Parallelism

Issue

For distributed training, I’d like to use Pipeline Parallelism + Distributed Data Parallelism + Flash Attention; however, Pipeline Parallelism appears not to work with Flash Attention. Specifically, I get this error when I try to use Flash Attention within a Pipeline Parallel model: RuntimeError: No available kernel. Aborting execution.

Question

My assumption is that Flash Attention should be compatible with Pipeline Parallelism. Therefore I'm wondering:

  1. Am I doing something wrong that’s causing Flash Attention to fail when using Pipeline Parallelism?
  2. If Flash Attention is in fact incompatible with Pipeline Parallelism, are there any plans to support Pipeline Parallelism + Flash Attention in the future?

Reproducing the issue

Below I include a script that reproduces the issue. I launch it with the following torchrun command, replacing num_gpus with the number of available GPUs (the pipeline-parallel case needs at least 2): torchrun --standalone --nproc_per_node=num_gpus -m my_script

When parallel="pdp" is passed as an argument to my main function, the script will run a forward pass on a toy model with Pipeline Parallelism + Distributed Data Parallelism + Flash Attention. If you change it to parallel="ddp", it will run the forward pass with just Distributed Data Parallelism + Flash Attention.

This way you can see that the "pdp" setup causes Flash Attention to fail with RuntimeError: No available kernel. Aborting execution, whereas the "ddp" setup works fine.
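
As a point of comparison, here is a minimal standalone sketch of the same flash-only SDPA call, which I'd expect to succeed on any GPU that supports Flash Attention with bf16 inputs (the shapes mirror the toy model below: batch 4, 12 heads, seqlen 10, head dim 768/12 = 64):

import torch
import torch.nn.functional as F

# assumes a CUDA device that supports the flash kernel (e.g. Ampere or newer)
q = torch.randn(4, 12, 10, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(4, 12, 10, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(4, 12, 10, 64, device="cuda", dtype=torch.bfloat16)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=0.0)
print(out.shape)  # torch.Size([4, 12, 10, 64])

Here is the full reproduction script: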

import os
import math
from datetime import timedelta
from collections import defaultdict

import torch
import torch.nn as nn
import torch.amp as amp
import torch.nn.functional as F
from torch.distributed.pipeline.sync import Pipe
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import rpc, init_process_group, new_group, Backend, destroy_process_group, barrier


class FlashAttention(nn.Module):

    def __init__(self, d_model: int, n_heads: int, dropout: float) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = dropout

    def forward(self, x) -> torch.Tensor:
        """
        x: (N, T, D)
        """
        n, t, _ = x.shape
        # (N, H, T, E)
        q = self.W_q(x).view(n, t, -1, self.d_head).transpose(1, 2)
        k = self.W_k(x).view(n, t, -1, self.d_head).transpose(1, 2)
        v = self.W_v(x).view(n, t, -1, self.d_head).transpose(1, 2)
        dropout = self.dropout if self.training else 0.0
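        # restrict SDPA to the flash backend only, so unsupported inputs raise
        # "No available kernel" instead of silently falling back to another kernel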
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            # (N, H, T, E)
            a = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=dropout)
        # (N, T, d_model)
        o = self.W_o(a.transpose(1, 2).reshape(n, t, self.n_heads * self.d_head))
        return o


class ToyModel(nn.Module):

    def __init__(self, vocab_size=100, d_model=768, n_heads=12, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        embed = nn.Embedding(vocab_size, d_model)
        attn1 = FlashAttention(d_model, n_heads, dropout)
        attn2 = FlashAttention(d_model, n_heads, dropout)
        out = nn.Linear(d_model, vocab_size)
        self.layers = nn.Sequential(embed, attn1, attn2, out)
        self.loss_fn = F.cross_entropy

    def as_sequential(self):
        return self.layers

    def forward(self, x, y=None):
        loss = None
        logits = self.layers(x)
        if y is not None:
            loss = self.loss_fn(
                input=logits.view(-1, self.vocab_size),
                target=y.view(-1),
            )
        return logits, loss


class PipelineParallel(nn.Module):

    def __init__(self, pipeline, loss_fn, vocab_size, ignore_index=-1):
        super().__init__()
        self.pipeline = pipeline
        self.loss_fn = loss_fn
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index

    def __call__(self, x, y=None):
        loss = None
        logits = self.pipeline(x).local_value()

        if y is not None:
            loss = self.loss_fn(
                input=logits.view(-1, self.vocab_size),
                target=y.view(-1),
                ignore_index=self.ignore_index,
            )
        return logits, loss

    def make_data_parallel(self, process_group):
        self.pipeline = DDP(self.pipeline, process_group=process_group)

    @classmethod
    def from_sequential_model(
        cls,
        model,
        num_pipeline_stages,
        num_micro_batches,
        activation_ckpt_mode,
        deferred_batch_norm,
    ):
        """
        Approach copied from https://medium.com/pytorch/pytorch-data-parallel-best-practices-on-google-cloud-6c8da2be180d
        """
        local_rank = int(os.environ["LOCAL_RANK"])
        first_stage_rank = (local_rank // num_pipeline_stages) * num_pipeline_stages

        stage_idx = 0
        num_stage_params = 0
        stage_to_layers = defaultdict(list)
        sequential = model.as_sequential()
        num_model_params = sum([p.numel() for p in sequential.parameters()])
        max_params_per_stage = math.ceil(num_model_params / num_pipeline_stages)

        # partition the model into stages
        for layer in sequential:
            num_layer_params = sum([p.numel() for p in layer.parameters()])
            is_stage_full = num_stage_params + num_layer_params > max_params_per_stage
            is_last_stage = stage_idx == num_pipeline_stages - 1
            if is_stage_full and not is_last_stage:
                stage_idx += 1
                num_stage_params = num_layer_params
            else:
                num_stage_params += num_layer_params
            stage_to_layers[stage_idx].append(layer)

        # put each stage onto a GPU
        for i, stage in stage_to_layers.items():
            device = f"cuda:{first_stage_rank+i}"
            for layer in stage:
                layer.to(device=device)

        assert len(stage_to_layers) == num_pipeline_stages
        pipeline = Pipe(
            module=nn.Sequential(*[nn.Sequential(*stage_to_layers[j]) for j in range(num_pipeline_stages)]),
            chunks=num_micro_batches,
            checkpoint=activation_ckpt_mode,
            deferred_batch_norm=deferred_batch_norm,
        )
        return cls(pipeline=pipeline, loss_fn=model.loss_fn, vocab_size=model.vocab_size)


def main(num_pipeline_stages=2, batch_size=4, parallel="pdp"):

    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    master_addr = os.environ['MASTER_ADDR']
    master_port = os.environ['MASTER_PORT']

    if parallel == "pdp":
        init_process_group(
            init_method='tcp://' + str(master_addr) + ':' + str(master_port),
            backend=Backend.GLOO, rank=rank, world_size=world_size
        )
        rpc.init_rpc(
            "worker:" + str(rank),
            rank=rank,
            world_size=world_size,
        )
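        # DDP should only wrap the first ("driver") rank of each pipeline,
        # so build a separate NCCL group over those ranks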
        driver_ranks = [i for i in range(world_size) if i % num_pipeline_stages == 0]
        process_group = new_group(ranks=driver_ranks, backend=Backend.NCCL, timeout=timedelta(days=365))
    elif parallel == "ddp":
        process_group = None
        init_process_group(rank=rank, world_size=world_size, backend=Backend.NCCL)

    vocab_size = 100
    d_model = 768
    seqlen = 10

    rand_ints = torch.randint(0, vocab_size, (batch_size, seqlen+1))
    x, y = rand_ints[:, :-1], rand_ints[:, 1:]
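    # autocast to bf16 is meant to make the Linear projections (and hence q/k/v)
    # run in bf16, which is a dtype the flash kernel supports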
    amp_context = amp.autocast(device_type="cuda", dtype=torch.bfloat16)

    if parallel == "ddp":
        model = ToyModel(vocab_size=vocab_size, d_model=d_model)
        model = model.to(local_rank)
        model = DDP(model, device_ids=[local_rank])
        with amp_context:
            logits, loss = model(x.to(local_rank), y.to(local_rank))

    elif parallel == "pdp" and rank % num_pipeline_stages == 0:
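        # only the driver rank of each pipeline runs the forward pass; devices are
        # hard-coded for a single 2-stage pipeline (inputs on cuda:0, logits/labels on cuda:1)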
        input_device = "cuda:0"
        output_device = "cuda:1"
        model = ToyModel(vocab_size=vocab_size, d_model=d_model)
        model = PipelineParallel.from_sequential_model(model, num_pipeline_stages, 2, "never", False)
        model.make_data_parallel(process_group=process_group)
        with amp_context:
            logits, loss = model(x.to(input_device), y.to(output_device))

    barrier(process_group)
    if parallel == "pdp":
        rpc.shutdown(graceful=True)
    destroy_process_group()


if __name__ == "__main__":
    main()