I’m trying to reimplement Whisper inference with the performance tricks inspired by gpt-fast.
I’ll get round to implementing int8 quantisation for the linear layers eventually, but as a first step I wanted to see how far I could get with CUDA graphs and torch.compile, as described in the blog post. The main issue I’m having is that the ‘improved’, CUDA-graph-friendly KV cache implementation from gpt-fast seems to slow my inference down significantly compared to my initial naive implementation. I think I’m seriously misunderstanding something.
I’ve taken OpenAI’s Whisper implementation (whisper/whisper/model.py at main · openai/whisper · GitHub) and modified it so that it is CUDA-graph friendly, i.e. torch.compile(mode="reduce-overhead", fullgraph=True) doesn’t throw an error. I reimplemented the KV cache using the gpt-fast method, as well as with a CUDA-graph-unfriendly dynamic memory allocation method. Both are in the following code:
import math
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
@dataclass
class ModelConfig:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
n_vocab: Optional[int] = None
def set_vocab_size(self, vocab_size: int):
self.n_vocab = vocab_size
class KVCache(nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
n_heads: int,
head_dim: int,
dtype=torch.bfloat16,
):
super().__init__()
# New
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
# End New
# Old
# self.k_cache = None
# self.v_cache = None
# End Old
def get_cache(self):
# Only intended to be used for cross attention
return self.k_cache, self.v_cache
def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val, v_val: [B, H, L, D]
k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val
return k_out, v_out
def update_old(self, input_pos, k_val, v_val):
if self.k_cache is None:
self.k_cache, self.v_cache = k_val, v_val
else:
k_val = torch.cat((self.k_cache, k_val), dim=2).detach()
v_val = torch.cat((self.v_cache, v_val), dim=2).detach()
self.k_cache, self.v_cache = k_val, v_val
return k_val, v_val
def sinusoids(
length: int, channels: int, max_timescale: float = 10000
) -> torch.Tensor:
"""Returns sinusoids for positional embedding"""
if channels % 2 != 0:
raise ValueError(
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
)
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(
-log_timescale_increment * torch.arange(channels // 2)
)
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
class EncoderAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
        assert n_state % n_head == 0, "n_head does not evenly divide n_state"
self.n_head = n_head
self.d_head = n_state // n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
def forward(
self,
xa: Tensor,
):
q = self.query(xa)
k = self.key(xa)
v = self.value(xa)
# Reshape for correct format
batch_size, source_seq_len, _ = k.shape
batch_size, target_seq_len, _ = q.shape
q = q.view(
batch_size, target_seq_len, self.n_head, self.d_head
).transpose(1, 2)
k = k.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
v = v.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
wv = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
is_causal=False,
)
wv = wv.transpose(1, 2).view(
batch_size,
target_seq_len,
self.n_head * self.d_head,
)
return self.out(wv)
class CrossAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
        assert n_state % n_head == 0, "n_head does not evenly divide n_state"
self.n_head = n_head
self.d_head = n_state // n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
self.kv_cache: KVCache | None = None
def prefill_kv(self, xa: torch.Tensor):
assert self.kv_cache is not None, "No kv_cache"
k = self.key(xa)
v = self.value(xa)
# Reshape for correct format
batch_size, source_seq_len, _ = k.shape
k = k.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
v = v.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
self.kv_cache.k_cache = k
self.kv_cache.v_cache = v
return k, v
def forward(
self,
x: Tensor,
):
q = self.query(x)
batch_size, target_seq_len, _ = q.shape
q = q.view(
batch_size, target_seq_len, self.n_head, self.d_head
).transpose(1, 2)
k, v = self.kv_cache.get_cache()
wv = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
is_causal=False,
)
wv = wv.transpose(1, 2).view(
batch_size,
target_seq_len,
self.n_head * self.d_head,
)
return self.out(wv)
class CausalSelfAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
        assert n_state % n_head == 0, "n_head does not evenly divide n_state"
self.n_head = n_head
self.d_head = n_state // n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)
self.kv_cache: KVCache | None = None
def get_kv(self, x: torch.Tensor, input_pos: torch.Tensor):
# Self attn
k = self.key(x)
v = self.value(x)
# Reshape
batch_size, source_seq_len, _ = k.shape
k = k.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
v = v.view(
batch_size, source_seq_len, self.n_head, self.d_head
).transpose(1, 2)
# New
k, v = self.kv_cache.update(k_val=k, v_val=v, input_pos=input_pos)
# End New
# Old
# k, v = self.kv_cache.update_old(k_val=k, v_val=v, input_pos=input_pos)
# End Old
return k, v
def forward(
self,
x: Tensor,
mask: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
):
q = self.query(x)
batch_size, target_seq_len, _ = q.shape
q = q.view(
batch_size, target_seq_len, self.n_head, self.d_head
).transpose(1, 2)
k, v = self.get_kv(x, input_pos=input_pos)
# New
wv = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=mask,
)
# End New
# Old
# wv = F.scaled_dot_product_attention(
# query=q,
# key=k,
# value=v,
# is_causal=False, # This fine since we never prefill
# )
# End Old
# (bz, nh, L, dh) -> (bz, L, nh, dh) -> (bz, L, d)
wv = wv.transpose(1, 2).view(
batch_size, target_seq_len, self.n_head * self.d_head
)
return self.out(wv)
class EncoderAttentionBlock(nn.Module):
def __init__(
self, n_state: int, n_head: int, cross_attention: bool = False
):
super().__init__()
self.attn = EncoderAttention(n_state, n_head)
self.attn_ln = nn.LayerNorm(n_state)
n_mlp = n_state * 4
self.mlp = nn.Sequential(
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
)
self.mlp_ln = nn.LayerNorm(n_state)
def forward(
self,
xa: Tensor,
):
xa = xa + self.attn(
self.attn_ln(xa),
)
xa = xa + self.mlp(self.mlp_ln(xa))
return xa
class DecoderAttentionBlock(nn.Module):
def __init__(
self, n_state: int, n_head: int, cross_attention: bool = False
):
super().__init__()
self.attn = CausalSelfAttention(n_state, n_head)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = (
CrossAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
)
self.mlp_ln = nn.LayerNorm(n_state)
def forward(
self,
x: Tensor,
mask: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
):
x = x + self.attn(
self.attn_ln(x),
mask=mask,
input_pos=input_pos,
)
x = x + self.cross_attn(self.cross_attn_ln(x))
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(
n_state, n_state, kernel_size=3, stride=2, padding=1
)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[EncoderAttentionBlock] = nn.ModuleList(
[EncoderAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = nn.LayerNorm(n_state)
def forward(self, xa: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
xa = F.gelu(self.conv1(xa))
xa = F.gelu(self.conv2(xa))
xa = xa.permute(0, 2, 1)
assert (
xa.shape[1:] == self.positional_embedding.shape
), f"incorrect audio shape: {xa.shape[1:]} != {self.positional_embedding.shape}"
xa = (xa + self.positional_embedding).to(xa.dtype)
for block in self.blocks:
xa = block(xa)
xa = self.ln_post(xa)
return xa
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.blocks: Iterable[DecoderAttentionBlock] = nn.ModuleList(
[
DecoderAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
)
self.ln = nn.LayerNorm(n_state)
self.register_buffer("causal_mask", None, persistent=False)
# Needs to be adjusted for train
def forward(
self,
x: Tensor,
input_pos: Tensor | None = None,
):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
"""
# Works for batched inference
mask = self.causal_mask[None, None, input_pos]
x = self.token_embedding(x) + self.positional_embedding[input_pos]
for block in self.blocks:
x = block(x, mask=mask, input_pos=input_pos)
x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
def setup_cache(
self,
xa: Tensor,
batch_size,
max_seq_len=4096,
max_audio_len=1500,
):
        # Resets and prefills the kv_cache
self.causal_mask = torch.tril(
torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
)
# Init cache
for b in self.blocks:
b.attn.kv_cache = KVCache(
max_batch_size=batch_size,
max_seq_length=max_seq_len,
n_heads=8,
head_dim=64,
).cuda()
b.cross_attn.kv_cache = KVCache(
max_batch_size=batch_size,
max_seq_length=max_audio_len,
n_heads=8,
head_dim=64,
).cuda()
b.cross_attn.prefill_kv(xa)
class AmtEncoderDecoder(nn.Module):
def __init__(self, dims: ModelConfig):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
@property
def device(self):
return next(self.parameters()).device
And my inference code is just doing greedy decoding, something like:
audio_features = model.encoder(xa=log_mels)
model.decoder.setup_cache(xa=audio_features, batch_size=seq.shape[0])
model.cuda()
for idx in (
pbar := tqdm(
range(min_prefix_len, MAX_SEQ_LEN - 1),
total=MAX_SEQ_LEN - (min_prefix_len + 1),
leave=False,
)
):
if idx == min_prefix_len:
logits = model.decoder(
x=seq[:, :idx],
input_pos=torch.arange(0, idx, device=seq.device),
)
else:
logits = model.decoder(
x=seq[:, idx - 1 : idx],
input_pos=torch.tensor(
[idx], device=seq.device, dtype=torch.int
),
)
seq[:, idx] = torch.argmax(logits[:, -1], dim=-1)
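For the torch.compile() numbers I’m only wrapping the single-token decode step, roughly like this (simplified sketch rather than my exact script, following the gpt-fast approach; the prefill call runs eagerly):

def decode_one_token(decoder, x, input_pos):
    # Single greedy decode step: this is the part wrapped with torch.compile
    logits = decoder(x=x, input_pos=input_pos)
    return torch.argmax(logits[:, -1], dim=-1)

decode_one_token = torch.compile(
    decode_one_token, mode="reduce-overhead", fullgraph=True
)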
When using the dynamic implementation without torch.compile(), I get 450+ tok/s on my 4090 with batch_size=1. With my static (gpt-fast style) implementation, I get around 290 tok/s without torch.compile() and about 320 tok/s with torch.compile(mode="reduce-overhead", fullgraph=True).
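For reference, the tok/s numbers are just wall-clock time over the decode loop with a CUDA sync before and after, roughly like this (simplified):

import time

torch.cuda.synchronize()
start = time.perf_counter()
# ... run the greedy decoding loop above ...
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
tokens_generated = (MAX_SEQ_LEN - 1) - min_prefix_len  # iterations of the loop
print(f"{tokens_generated / elapsed:.1f} tok/s")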
I’ve obviously implemented something wrong, but I’m having a hard time tracking it down. I’m pretty sure the issue is either that the attention calculation isn’t realising it can skip a large chunk of the computation because of the attention mask, or that updating and returning the statically allocated KV cache takes longer than using torch.cat.
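To try to narrow it down, I’m planning to benchmark a single decode step of the two cache strategies in isolation, something along these lines (rough sketch; the dims are the ones I hard-code in setup_cache, and CUR_LEN is just an arbitrary decode position):

import torch
import torch.nn.functional as F

B, H, D, MAX_LEN, CUR_LEN = 1, 8, 64, 4096, 256
device, dtype = "cuda", torch.float32

q = torch.randn(B, H, 1, D, device=device, dtype=dtype)
k_new = torch.randn(B, H, 1, D, device=device, dtype=dtype)
v_new = torch.randn(B, H, 1, D, device=device, dtype=dtype)

# Static cache: full-length buffers, attention over all MAX_LEN positions with a boolean mask
k_static = torch.zeros(B, H, MAX_LEN, D, device=device, dtype=dtype)
v_static = torch.zeros(B, H, MAX_LEN, D, device=device, dtype=dtype)
causal_mask = torch.tril(
    torch.ones(MAX_LEN, MAX_LEN, dtype=torch.bool, device=device)
)
input_pos = torch.tensor([CUR_LEN], device=device)

def step_static():
    k_static[:, :, input_pos] = k_new
    v_static[:, :, input_pos] = v_new
    mask = causal_mask[None, None, input_pos]  # (1, 1, 1, MAX_LEN)
    return F.scaled_dot_product_attention(q, k_static, v_static, attn_mask=mask)

# Dynamic cache: only the CUR_LEN valid positions, grown with torch.cat
k_dyn = torch.randn(B, H, CUR_LEN, D, device=device, dtype=dtype)
v_dyn = torch.randn(B, H, CUR_LEN, D, device=device, dtype=dtype)

def step_dynamic():
    k = torch.cat((k_dyn, k_new), dim=2)
    v = torch.cat((v_dyn, v_new), dim=2)
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)

def bench(fn, iters=1000):
    for _ in range(10):  # warmup
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per decode step

print("static :", bench(step_static))
print("dynamic:", bench(step_dynamic))

This doesn’t capture what torch.compile does to either version, but it should at least tell me whether masked SDPA over the full 4096-position cache is inherently much more expensive than SDPA over only the valid prefix.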
Any advice would be appreciated.