CUDA graphs kv_cache implementation slows inference significantly

I'm trying to reimplement Whisper inference with the performance tricks inspired by gpt-fast.

I'm going to get round to implementing int8 quantisation for the linear layers eventually. However, as a first attempt I wanted to see how far I could get with CUDA graphs and torch.compile, as described in the blogpost. The main issue I'm having is that the ‘improved’, CUDA-graph-friendly KV cache implementation from gpt-fast seems to slow down my inference significantly compared to my initial naive implementation. I think I'm seriously misunderstanding something.

I've taken OpenAI's Whisper implementation (whisper/whisper/model.py at main · openai/whisper · GitHub) and modified it so that it is CUDA-graph friendly, i.e. torch.compile(mode="reduce-overhead", fullgraph=True) doesn't throw an error. I reimplemented the KV cache using the gpt-fast method, as well as with a CUDA-graph-unfriendly dynamic memory allocation method. Both are in the following code:

import math  # needed by sinusoids() below

import torch
import torch.nn.functional as F

from torch import Tensor, nn
from dataclasses import dataclass
from typing import Dict, Iterable, Optional


@dataclass
class ModelConfig:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
    n_vocab: Optional[int] = None

    def set_vocab_size(self, vocab_size: int):
        self.n_vocab = vocab_size


class KVCache(nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        dtype=torch.bfloat16,
    ):
        super().__init__()

        # New
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
        # End New

        # Old
        # self.k_cache = None
        # self.v_cache = None
        # End Old

    def get_cache(self):
        # Only intended to be used for cross attention
        return self.k_cache, self.v_cache

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val, v_val: [B, H, L, D]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out

    def update_old(self, input_pos, k_val, v_val):
        if self.k_cache is None:
            self.k_cache, self.v_cache = k_val, v_val
        else:
            k_val = torch.cat((self.k_cache, k_val), dim=2).detach()
            v_val = torch.cat((self.v_cache, v_val), dim=2).detach()
            self.k_cache, self.v_cache = k_val, v_val

        return k_val, v_val


def sinusoids(
    length: int, channels: int, max_timescale: float = 10000
) -> torch.Tensor:
    """Returns sinusoids for positional embedding"""
    if channels % 2 != 0:
        raise ValueError(
            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
        )
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(
        -log_timescale_increment * torch.arange(channels // 2)
    )
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)


class EncoderAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        assert n_state % n_head == 0, "n_head does not evenly divide n_state"

        self.n_head = n_head
        self.d_head = n_state // n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(
        self,
        xa: Tensor,
    ):
        q = self.query(xa)
        k = self.key(xa)
        v = self.value(xa)

        # Reshape for correct format
        batch_size, source_seq_len, _ = k.shape
        batch_size, target_seq_len, _ = q.shape
        q = q.view(
            batch_size, target_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)
        k = k.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)
        v = v.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)

        wv = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            is_causal=False,
        )
        wv = wv.transpose(1, 2).view(
            batch_size,
            target_seq_len,
            self.n_head * self.d_head,
        )

        return self.out(wv)


class CrossAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        assert n_state % n_head == 0, "n_head does not evenly divide n_state"

        self.n_head = n_head
        self.d_head = n_state // n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.kv_cache: KVCache | None = None

    def prefill_kv(self, xa: torch.Tensor):
        assert self.kv_cache is not None, "No kv_cache"
        k = self.key(xa)
        v = self.value(xa)

        # Reshape for correct format
        batch_size, source_seq_len, _ = k.shape
        k = k.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)
        v = v.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)

        self.kv_cache.k_cache = k
        self.kv_cache.v_cache = v

        return k, v

    def forward(
        self,
        x: Tensor,
    ):
        q = self.query(x)
        batch_size, target_seq_len, _ = q.shape
        q = q.view(
            batch_size, target_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)
        k, v = self.kv_cache.get_cache()

        wv = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            is_causal=False,
        )
        wv = wv.transpose(1, 2).view(
            batch_size,
            target_seq_len,
            self.n_head * self.d_head,
        )

        return self.out(wv)


class CausalSelfAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        assert n_state % n_head == 0, "n_head does not evenly divide n_state"

        self.n_head = n_head
        self.d_head = n_state // n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.kv_cache: KVCache | None = None

    def get_kv(self, x: torch.Tensor, input_pos: torch.Tensor):
        # Self attn
        k = self.key(x)
        v = self.value(x)

        # Reshape
        batch_size, source_seq_len, _ = k.shape
        k = k.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)
        v = v.view(
            batch_size, source_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)

        # New
        k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos)
        # End New

        # Old
        # k, v = self.kv_cache.update_old(k_val=k, v_val=v, input_pos=input_pos)
        # End Old

        return k, v

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ):
        q = self.query(x)

        batch_size, target_seq_len, _ = q.shape
        q = q.view(
            batch_size, target_seq_len, self.n_head, self.d_head
        ).transpose(1, 2)

        k, v = self.get_kv(x, input_pos=input_pos)

        # New
        wv = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            attn_mask=mask,
        )
        # End New

        # Old
        # wv = F.scaled_dot_product_attention(
        #     query=q,
        #     key=k,
        #     value=v,
        #     is_causal=False,  # This is fine since we never prefill
        # )
        # End Old

        # (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d)
        wv = wv.transpose(1, 2).view(
            batch_size, target_seq_len, self.n_head * self.d_head
        )

        return self.out(wv)


class EncoderAttentionBlock(nn.Module):
    def __init__(
        self, n_state: int, n_head: int, cross_attention: bool = False
    ):
        super().__init__()
        self.attn = EncoderAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(
        self,
        xa: Tensor,
    ):
        xa = xa + self.attn(
            self.attn_ln(xa),
        )
        xa = xa + self.mlp(self.mlp_ln(xa))

        return xa


class DecoderAttentionBlock(nn.Module):
    def __init__(
        self, n_state: int, n_head: int, cross_attention: bool = False
    ):
        super().__init__()
        self.attn = CausalSelfAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)
        self.cross_attn = (
            CrossAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ):
        x = x + self.attn(
            self.attn_ln(x),
            mask=mask,
            input_pos=input_pos,
        )
        x = x + self.cross_attn(self.cross_attn_ln(x))
        x = x + self.mlp(self.mlp_ln(x))

        return x


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(
            n_state, n_state, kernel_size=3, stride=2, padding=1
        )
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[EncoderAttentionBlock] = nn.ModuleList(
            [EncoderAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = nn.LayerNorm(n_state)

    def forward(self, xa: Tensor):
        """
        xa : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        xa = F.gelu(self.conv1(xa))
        xa = F.gelu(self.conv2(xa))
        xa = xa.permute(0, 2, 1)

        assert (
            xa.shape[1:] == self.positional_embedding.shape
        ), f"incorrect audio shape: {xa.shape[1:]} != {self.positional_embedding.shape}"
        xa = (xa + self.positional_embedding).to(xa.dtype)

        for block in self.blocks:
            xa = block(xa)

        xa = self.ln_post(xa)
        return xa


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[DecoderAttentionBlock] = nn.ModuleList(
            [
                DecoderAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = nn.LayerNorm(n_state)
        self.register_buffer("causal_mask", None, persistent=False)

    # Needs to be adjusted for train
    def forward(
        self,
        x: Tensor,
        input_pos: Tensor | None = None,
    ):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        """
        # Works for batched inference
        mask = self.causal_mask[None, None, input_pos]
        x = self.token_embedding(x) + self.positional_embedding[input_pos]

        for block in self.blocks:
            x = block(x, mask=mask, input_pos=input_pos)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits

    def setup_cache(
        self,
        xa: Tensor,
        batch_size,
        max_seq_len=4096,
        max_audio_len=1500,
    ):
        # Resets and prefills the kv_cache
        self.causal_mask = torch.tril(
            torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
        )
        # Init cache
        for b in self.blocks:
            b.attn.kv_cache = KVCache(
                max_batch_size=batch_size,
                max_seq_length=max_seq_len,
                n_heads=8,
                head_dim=64,
            ).cuda()

            b.cross_attn.kv_cache = KVCache(
                max_batch_size=batch_size,
                max_seq_length=max_audio_len,
                n_heads=8,
                head_dim=64,
            ).cuda()
            b.cross_attn.prefill_kv(xa)


class AmtEncoderDecoder(nn.Module):
    def __init__(self, dims: ModelConfig):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

    @property
    def device(self):
        return next(self.parameters()).device
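
As a quick sanity check that the static cache is functionally equivalent to the dynamic one, here is a standalone sketch (not part of the model code; the batch size, head count, head dim and dtype are just chosen to match setup_cache above):

# Sanity-check sketch: the static update() and the dynamic update_old() should
# agree on every position written so far (the static cache is zero-padded beyond).
import torch

static_cache = KVCache(max_batch_size=1, max_seq_length=4096, n_heads=8, head_dim=64)
dynamic_cache = KVCache(max_batch_size=1, max_seq_length=4096, n_heads=8, head_dim=64)
dynamic_cache.k_cache = None  # emulate the old dynamic behaviour
dynamic_cache.v_cache = None

for pos in range(4):
    k = torch.randn(1, 8, 1, 64, dtype=torch.bfloat16)
    v = torch.randn(1, 8, 1, 64, dtype=torch.bfloat16)
    input_pos = torch.tensor([pos])
    k_static, v_static = static_cache.update(input_pos, k, v)
    k_dynamic, v_dynamic = dynamic_cache.update_old(input_pos, k, v)
    # Only compare the positions written so far
    assert torch.equal(k_static[:, :, : pos + 1], k_dynamic)
    assert torch.equal(v_static[:, :, : pos + 1], v_dynamic)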

And my inference code is just doing greedy decoding, something like

    audio_features = model.encoder(xa=log_mels)
    model.decoder.setup_cache(xa=audio_features, batch_size=seq.shape[0])
    model.cuda()

    for idx in (
        pbar := tqdm(
            range(min_prefix_len, MAX_SEQ_LEN - 1),
            total=MAX_SEQ_LEN - (min_prefix_len + 1),
            leave=False,
        )
    ):
        if idx == min_prefix_len:
            logits = model.decoder(
                x=seq[:, :idx],
                input_pos=torch.arange(0, idx, device=seq.device),
            )
        else:
            logits = model.decoder(
                x=seq[:, idx - 1 : idx],
                input_pos=torch.tensor(
                    [idx - 1], device=seq.device, dtype=torch.int
                ),
            )

        seq[:, idx] = torch.argmax(logits[:, -1], dim=-1)
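
For the torch.compile numbers below, the single-token decode step is wrapped roughly along these lines (just a sketch, not necessarily exactly what I ran):

# Sketch: compile only the decoder so the per-token step has static shapes and
# can be captured with CUDA graphs via mode="reduce-overhead".
compiled_decoder = torch.compile(
    model.decoder, mode="reduce-overhead", fullgraph=True
)

# Inside the loop above, the single-token branch then becomes:
logits = compiled_decoder(
    x=seq[:, idx - 1 : idx],
    input_pos=torch.tensor([idx - 1], device=seq.device, dtype=torch.int),
)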

When using the dynamic implementation without torch.compile(), I get around 450+ tok/s on my 4090 with batch_size=1. With my static cache implementation, I get around 290 tok/s without torch.compile() and around 320 tok/s with torch.compile(mode="reduce-overhead", fullgraph=True).

I've obviously implemented something wrong; however, I'm having a hard time tracking it down. I'm pretty sure the issue is either that the attention calculation doesn't realise it can skip a bunch of the computation due to the attention mask, or that updating and returning the statically allocated KV cache takes longer than using torch.cat.
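
To narrow down the second hypothesis, the two update paths can be timed in isolation with something like the following sketch (shapes match setup_cache; the ~100 cached positions are an arbitrary example):

# Sketch: time the static index_put-style update against a torch.cat append
# in isolation, using CUDA events (bs=1, one new token, ~100 cached positions).
import torch

def time_fn(fn, iters=1000, warmup=10):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

static_cache = KVCache(max_batch_size=1, max_seq_length=4096, n_heads=8, head_dim=64).cuda()
k_new = torch.randn(1, 8, 1, 64, dtype=torch.bfloat16, device="cuda")
v_new = torch.randn(1, 8, 1, 64, dtype=torch.bfloat16, device="cuda")
input_pos = torch.tensor([100], device="cuda")
print("static update:", time_fn(lambda: static_cache.update(input_pos, k_new, v_new)), "ms")

k_old = torch.randn(1, 8, 100, 64, dtype=torch.bfloat16, device="cuda")
print("torch.cat append:", time_fn(lambda: torch.cat((k_old, k_new), dim=2)), "ms")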

Any advice would be appreciated.

Here are the PyTorch profiles (for each run, the first table is sorted by CUDA time and the second by CPU time) from a generation cycle running from the BOS token to the EOS token, without torch.compile().
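
(For context, a torch.profiler harness along these lines produces tables like the ones below; generate() here is a hypothetical stand-in for the greedy decoding loop above, and the exact arguments are assumed.)

# Rough profiling harness (sketch); generate() is a stand-in for the decoding loop.
from torch.profiler import ProfilerActivity, profile, record_function

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("model_inference"):
        generate(model, seq)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))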

Old (dynamic kv cache alloc):

                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        15.28%     587.265ms       100.00%        3.844s        3.844s       0.000us         0.00%     591.835ms     591.835ms             1  
                                           aten::linear         8.53%     328.038ms        48.16%        1.851s      15.543us       0.000us         0.00%     367.562ms       3.086us        119104  
                     aten::scaled_dot_product_attention         3.78%     145.199ms        18.39%     706.801ms      31.782us       0.000us         0.00%     170.576ms       7.670us         22239  
std::enable_if<!(false), void>::type internal::gemvx...         0.00%       0.000us         0.00%       0.000us       0.000us     168.268ms        30.47%     168.268ms       2.479us         67872  
                                            aten::addmm        10.13%     389.251ms        12.56%     482.742ms       8.111us     153.624ms        27.82%     162.017ms       2.722us         59520  
              aten::_scaled_dot_product_flash_attention         2.20%      84.565ms        12.79%     491.691ms      28.937us       0.000us         0.00%     150.150ms       8.837us         16992  
                         aten::_flash_attention_forward         5.90%     226.868ms        10.17%     390.889ms      23.004us     141.721ms        25.66%     150.150ms       8.837us         16992  
void pytorch_flash::flash_fwd_splitkv_kernel<pytorch...         0.00%       0.000us         0.00%       0.000us       0.000us      94.042ms        17.03%      94.042ms       6.768us         13896  
                                            aten::copy_         3.17%     121.922ms         5.00%     192.307ms       3.950us      77.508ms        14.03%      84.659ms       1.739us         48689  
                                         aten::_to_copy         2.53%      97.347ms         8.85%     340.313ms       7.148us       0.000us         0.00%      75.312ms       1.582us         47608  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      73.022ms        13.22%      73.022ms       1.648us         44312  
                                       aten::layer_norm         2.73%     104.974ms        14.11%     542.330ms      14.465us       0.000us         0.00%      69.164ms       1.845us         37493  
                                               aten::to         1.62%      62.183ms         9.55%     367.196ms       7.399us       0.000us         0.00%      69.105ms       1.392us         49629  
                                aten::native_layer_norm         5.15%     197.965ms         8.97%     344.805ms      13.156us      52.568ms         9.52%      56.809ms       2.168us         26209  
                                              aten::add         2.94%     112.852ms         3.71%     142.606ms       5.441us      52.169ms         9.45%      55.625ms       2.122us         26210  
void at::native::(anonymous namespace)::vectorized_l...         0.00%       0.000us         0.00%       0.000us       0.000us      52.568ms         9.52%      52.568ms       2.006us         26209  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      50.904ms         9.22%      50.904ms       2.000us         25452  
                                       cudaLaunchKernel        13.82%     531.137ms        13.82%     531.137ms       2.311us      31.410ms         5.69%      32.558ms       0.142us        229814  
                                           aten::matmul         0.65%      25.147ms         6.14%     236.129ms      23.765us       0.000us         0.00%      31.122ms       3.132us          9936  
                                              aten::cat         2.08%      80.087ms         2.73%     105.066ms       6.200us      28.032ms         5.08%      30.345ms       1.791us         16946  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.844s
Self CUDA time total: 552.302ms

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        15.28%     587.265ms       100.00%        3.844s        3.844s       0.000us         0.00%     591.835ms     591.835ms             1  
                                           aten::linear         8.53%     328.038ms        48.16%        1.851s      15.543us       0.000us         0.00%     367.562ms       3.086us        119104  
                     aten::scaled_dot_product_attention         3.78%     145.199ms        18.39%     706.801ms      31.782us       0.000us         0.00%     170.576ms       7.670us         22239  
                                       aten::layer_norm         2.73%     104.974ms        14.11%     542.330ms      14.465us       0.000us         0.00%      69.164ms       1.845us         37493  
                                       cudaLaunchKernel        13.82%     531.137ms        13.82%     531.137ms       2.311us      31.410ms         5.69%      32.558ms       0.142us        229814  
              aten::_scaled_dot_product_flash_attention         2.20%      84.565ms        12.79%     491.691ms      28.937us       0.000us         0.00%     150.150ms       8.837us         16992  
                                            aten::addmm        10.13%     389.251ms        12.56%     482.742ms       8.111us     153.624ms        27.82%     162.017ms       2.722us         59520  
                         aten::_flash_attention_forward         5.90%     226.868ms        10.17%     390.889ms      23.004us     141.721ms        25.66%     150.150ms       8.837us         16992  
                                               aten::to         1.62%      62.183ms         9.55%     367.196ms       7.399us       0.000us         0.00%      69.105ms       1.392us         49629  
                                aten::native_layer_norm         5.15%     197.965ms         8.97%     344.805ms      13.156us      52.568ms         9.52%      56.809ms       2.168us         26209  
                                         aten::_to_copy         2.53%      97.347ms         8.85%     340.313ms       7.148us       0.000us         0.00%      75.312ms       1.582us         47608  
                                           aten::matmul         0.65%      25.147ms         6.14%     236.129ms      23.765us       0.000us         0.00%      31.122ms       3.132us          9936  
                                            aten::empty         5.10%     196.126ms         5.13%     197.136ms       1.119us       0.000us         0.00%      31.000us       0.000us        176223  
                                            aten::copy_         3.17%     121.922ms         5.00%     192.307ms       3.950us      77.508ms        14.03%      84.659ms       1.739us         48689  
                                              aten::add         2.94%     112.852ms         3.71%     142.606ms       5.441us      52.169ms         9.45%      55.625ms       2.122us         26210  
                                        aten::embedding         0.06%       2.451ms         3.65%     140.132ms     198.207us       0.000us         0.00%     814.000us       1.151us           707  
                                     aten::index_select         0.12%       4.536ms         3.58%     137.528ms     194.523us     707.000us         0.13%     814.000us       1.151us           707  
                                              aten::cat         2.08%      80.087ms         2.73%     105.066ms       6.200us      28.032ms         5.08%      30.345ms       1.791us         16946  
                                    aten::empty_strided         2.62%     100.780ms         2.63%     101.032ms       1.564us       0.000us         0.00%       0.000us       0.000us         64601  
                                           aten::conv1d         0.00%      34.000us         2.14%      82.240ms      11.749ms       0.000us         0.00%       1.630ms     232.857us             7  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.844s
Self CUDA time total: 552.302ms

New (static kv cache alloc):

                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        12.55%     532.084ms       100.00%        4.240s        4.240s       0.000us         0.00%        2.158s        2.158s             1  
                     aten::scaled_dot_product_attention         4.92%     208.726ms        20.90%     886.116ms      42.168us       0.000us         0.00%        1.766s      84.046us         21014  
          aten::_scaled_dot_product_efficient_attention         1.07%      45.356ms         3.65%     154.676ms      18.520us       0.000us         0.00%        1.508s     180.577us          8352  
                     aten::_efficient_attention_forward         1.62%      68.620ms         2.39%     101.249ms      12.123us        1.501s        74.52%        1.508s     180.577us          8352  
fmha_cutlassF_bf16_aligned_64x64_rf_sm80(PyTorchMemE...         0.00%       0.000us         0.00%       0.000us       0.000us        1.501s        74.52%        1.501s     179.697us          8352  
                                           aten::linear         7.57%     321.151ms        45.15%        1.915s      16.325us       0.000us         0.00%     403.876ms       3.444us        117273  
                                            aten::addmm         9.18%     389.373ms        11.80%     500.306ms       8.538us     151.665ms         7.53%     179.207ms       3.058us         58596  
std::enable_if<!(false), void>::type internal::gemvx...         0.00%       0.000us         0.00%       0.000us       0.000us     166.042ms         8.24%     166.042ms       2.485us         66816  
                                            aten::copy_         4.10%     173.775ms         6.45%     273.383ms       4.226us      96.918ms         4.81%     128.812ms       1.991us         64689  
                                       cudaLaunchKernel        13.76%     583.412ms        13.76%     583.412ms       2.221us     122.733ms         6.09%     123.738ms       0.471us        262688  
                                         aten::_to_copy         3.10%     131.269ms        11.23%     476.347ms       7.489us       0.000us         0.00%     113.743ms       1.788us         63604  
                                               aten::to         1.92%      81.521ms        11.83%     501.653ms       7.641us       0.000us         0.00%     102.061ms       1.555us         65650  
              aten::_scaled_dot_product_flash_attention         1.05%      44.456ms         6.06%     256.992ms      30.682us       0.000us         0.00%      97.419ms      11.631us          8376  
                         aten::_flash_attention_forward         2.82%     119.474ms         4.82%     204.445ms      24.408us      84.567ms         4.20%      97.419ms      11.631us          8376  
                                       aten::layer_norm         2.35%      99.721ms        12.65%     536.588ms      14.477us       0.000us         0.00%      79.579ms       2.147us         37066  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      67.887ms         3.37%      67.887ms       1.556us         43630  
                                 aten::_index_put_impl_         1.91%      80.812ms         5.76%     244.367ms      14.629us      33.707ms         1.67%      67.027ms       4.013us         16704  
void pytorch_flash::flash_fwd_splitkv_kernel<pytorch...         0.00%       0.000us         0.00%       0.000us       0.000us      66.733ms         3.31%      66.733ms       7.990us          8352  
                                aten::native_layer_norm         4.88%     206.966ms         8.06%     341.659ms      13.242us      51.754ms         2.57%      63.873ms       2.476us         25802  
                                              aten::add         2.56%     108.485ms         3.25%     137.698ms       5.337us      50.706ms         2.52%      62.488ms       2.422us         25803  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 4.240s
Self CUDA time total: 2.014s

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        12.55%     532.084ms       100.00%        4.240s        4.240s       0.000us         0.00%        2.158s        2.158s             1  
                                           aten::linear         7.57%     321.151ms        45.15%        1.915s      16.325us       0.000us         0.00%     403.876ms       3.444us        117273  
                     aten::scaled_dot_product_attention         4.92%     208.726ms        20.90%     886.116ms      42.168us       0.000us         0.00%        1.766s      84.046us         21014  
                                       cudaLaunchKernel        13.76%     583.412ms        13.76%     583.412ms       2.221us     122.733ms         6.09%     123.738ms       0.471us        262688  
                                       aten::layer_norm         2.35%      99.721ms        12.65%     536.588ms      14.477us       0.000us         0.00%      79.579ms       2.147us         37066  
                                               aten::to         1.92%      81.521ms        11.83%     501.653ms       7.641us       0.000us         0.00%     102.061ms       1.555us         65650  
                                            aten::addmm         9.18%     389.373ms        11.80%     500.306ms       8.538us     151.665ms         7.53%     179.207ms       3.058us         58596  
                                         aten::_to_copy         3.10%     131.269ms        11.23%     476.347ms       7.489us       0.000us         0.00%     113.743ms       1.788us         63604  
                                aten::native_layer_norm         4.88%     206.966ms         8.06%     341.659ms      13.242us      51.754ms         2.57%      63.873ms       2.476us         25802  
                                            aten::copy_         4.10%     173.775ms         6.45%     273.383ms       4.226us      96.918ms         4.81%     128.812ms       1.991us         64689  
              aten::_scaled_dot_product_flash_attention         1.05%      44.456ms         6.06%     256.992ms      30.682us       0.000us         0.00%      97.419ms      11.631us          8376  
                                       aten::index_put_         1.25%      52.870ms         6.01%     255.001ms      15.266us       0.000us         0.00%      55.063ms       3.296us         16704  
                                           aten::matmul         0.63%      26.635ms         5.78%     244.892ms      25.035us       0.000us         0.00%      38.160ms       3.901us          9782  
                                 aten::_index_put_impl_         1.91%      80.812ms         5.76%     244.367ms      14.629us      33.707ms         1.67%      67.027ms       4.013us         16704  
                         aten::_flash_attention_forward         2.82%     119.474ms         4.82%     204.445ms      24.408us      84.567ms         4.20%      97.419ms      11.631us          8376  
                                        aten::embedding         0.06%       2.429ms         4.35%     184.347ms     264.866us       0.000us         0.00%       1.301ms       1.869us           696  
                                     aten::index_select         0.10%       4.323ms         4.29%     181.825ms     261.243us     697.000us         0.03%       1.299ms       1.866us           696  
                                            aten::empty         4.16%     176.200ms         4.18%     177.206ms       1.052us       0.000us         0.00%      40.000us       0.000us        168404  
          aten::_scaled_dot_product_efficient_attention         1.07%      45.356ms         3.65%     154.676ms      18.520us       0.000us         0.00%        1.508s     180.577us          8352  
                                      aten::logical_not         1.20%      50.999ms         3.35%     141.861ms       8.493us     197.000us         0.01%       6.395ms       0.383us         16704  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 4.240s
Self CUDA time total: 2.014s

If anyone else has this issue in the future: the problem seems to be fixed by adding the following context manager around my inference code:

    # Actually better for Inductor to codegen the attention here
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_mem_efficient=False, enable_math=True
    ):
        logits, next_tok_ids = decode_token(
            model,
            x=seq[:, idx - 1 : idx],
            input_pos=torch.tensor(
                [idx - 1], device=seq.device, dtype=torch.int
            ),
        )

I have a feeling that, under the hood, this context manager allows the attention calculation to ignore the masked positions in the static KV cache. It took me a couple of days to track this down; maybe it should be made clearer in the gpt-fast blogpost.
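
On newer PyTorch builds, torch.backends.cuda.sdp_kernel is deprecated in favour of torch.nn.attention.sdpa_kernel, so the equivalent there should be something like (a sketch, same decode_token call as above):

# Equivalent on newer PyTorch builds (sketch) using the torch.nn.attention API.
from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.MATH):  # let Inductor codegen the attention
    logits, next_tok_ids = decode_token(
        model,
        x=seq[:, idx - 1 : idx],
        input_pos=torch.tensor([idx - 1], device=seq.device, dtype=torch.int),
    )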

So the context manager you added was helpful because bs=1 inference means your Q * K operation is a vector-matrix multiplication rather than a matrix-matrix multiplication, and it turns out Inductor does a better job of codegening that than flash attention does.

If you run your code with TORCH_LOGS="output_code" you can see exactly the kernel that gets generated, which should help with debugging.
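
For example (generate.py here is just a stand-in for whatever script runs your decoding loop):

# Either via the environment variable:
#   TORCH_LOGS="output_code" python generate.py
# or, I believe, programmatically before running the compiled model (sketch):
import torch._logging
torch._logging.set_logs(output_code=True)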

Hi!

I actually have the same issue when using batch_size > 1: using this context manager gives a 2-3x increase in inference speed. Looking at the gpt-fast repo, they use this context manager too. Ideally I would expect flash attention to be faster, but I can live with the current speed using this context manager. I also tried flash attention with a max_seq_length smaller than 4096 and found that, in that case, the speed is comparable to using 4096 with this context manager.
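
For reference, the smaller cache just means passing a smaller max_seq_len when setting up the cache, e.g. (512 is an arbitrary example):

# Sketch: allocate a smaller static cache (and causal mask) if decoding never
# gets anywhere near 4096 tokens; 512 here is just an arbitrary example.
model.decoder.setup_cache(
    xa=audio_features, batch_size=seq.shape[0], max_seq_len=512
)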

Thanks for your help : )