How to optimize ALiBi with transformer implementation?

Hi!

I wanted to see what ALiBi/FIRE would give on Karpathy’s GPT-2 implementation (for the clarity of this post I removed almost all comments, as well as the from_pretrained method), so I made the following small changes:

from dataclasses import dataclass
import inspect
import torch
import torch.nn as nn
from torch.nn import functional as F

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x, attn_mask):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu    = nn.GELU(approximate='tanh')
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x, attn_mask):
        x = x + self.attn(self.ln_1(x), attn_mask)
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024 # max sequence length
    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 12 # number of layers
    n_head: int = 12 # number of heads
    n_embd: int = 768 # embedding dimension

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.register_buffer('attn_mask', None)

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight sharing scheme
        self.transformer.wte.weight = self.lm_head.weight

        # init params
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        # idx is of shape (B, T)
        B, T = idx.size()

        tok_emb = self.transformer.wte(idx)
        x = tok_emb
        if self.attn_mask is None or self.attn_mask.size(-1) != T:
            # The cached mask has shape (1, n_head, T, T), so the sequence length is read
            # from the last dimension. The per-head ALiBi biases are identical across layers
            # and batches, so we cache the mask and only rebuild it when the sequence length
            # changes (the biases are relative, so a new T needs a new (T, T) bias matrix).
            self.attn_mask = self.generate_causal_linear_bias_mask(T, self.config.n_head).to(idx.device) # ALiBi + causal mask
        for block in self.transformer.h:
            x = block(x, self.attn_mask)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        # start with all of the candidate parameters (that require grad)
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:  # master_process is a global from the (omitted) DDP training script
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        if master_process:
            print(f"using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer

    @staticmethod
    def generate_causal_linear_bias_mask(T, n_head):
        # Additive causal part: allowed positions end up at 0, future positions at -inf.
        causal_mask = torch.tril(torch.ones(T, T)).unsqueeze(0)
        causal_mask = causal_mask.masked_fill(causal_mask == 0, float("-inf"))
        causal_mask += -1

        # Per-head ALiBi slopes m_i = 2^(-8 i / n_head), shaped (n_head, 1, 1) for broadcasting.
        start_slope = 2 ** (-8 / n_head)
        slopes = torch.tensor([start_slope ** i for i in range(1, n_head + 1)]).float().unsqueeze(-1).unsqueeze(-1)
        # Negative relative distances -|i - j|, expanded to (n_head, T, T).
        indices = torch.arange(T).unsqueeze(0) - torch.arange(T).unsqueeze(1)
        indices = -indices.abs().float()
        indices = indices.unsqueeze(0).repeat(n_head, 1, 1)
        linear_biases = slopes * indices
        # Final shape: (1, n_head, T, T), broadcastable over the batch dimension.
        attn_mask = (causal_mask + linear_biases).unsqueeze(0)

        return attn_mask

I’d like to optimize this but I’m not sure how. I’m caching the ALiBi biases for as long as the sequence length stays the same (which is the case during almost all of training, except when evaluating). The way I generate the mask doesn’t look particularly costly to me.

It’s a bit slow. My guess is that we lose the flash attention backend when passing attn_mask to torch.nn.functional.scaled_dot_product_attention.
At the same time, since the mask is just an additive bias, maybe that isn’t really the issue, and the problem lies in how I cache the mask, how it’s accessed, or how it’s created?
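To check the backend hypothesis in isolation, here is a small sketch I’d use (assuming PyTorch >= 2.3 for torch.nn.attention.sdpa_kernel; shapes are arbitrary). If SDPA restricted to the flash backend rejects the additive mask, the full model must be falling back to a slower backend:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

B, H, T, D = 4, 12, 1024, 64
q, k, v = (torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
attn_mask = GPT.generate_causal_linear_bias_mask(T, H).to("cuda", torch.bfloat16)

# Restrict SDPA to the flash backend only: if it cannot handle the additive mask,
# this raises instead of silently falling back to the math/mem-efficient backends.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        print("flash backend accepted the additive mask")
    except RuntimeError as e:
        print("flash backend rejected the additive mask:", e)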

I put the category as “torch.compile” because the model is compiled and it is still really slow: I get about half the throughput of the original implementation, which uses F.scaled_dot_product_attention(q, k, v, is_causal=True), on the same hardware.
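To see how much of that gap comes from the attention call itself rather than from the mask caching or from torch.compile, here is a quick micro-benchmark sketch (shapes chosen arbitrarily to roughly match the GPT-2 config above):

import torch
import torch.nn.functional as F
from torch.utils import benchmark

B, H, T, D = 16, 12, 1024, 64
q, k, v = (torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))
attn_mask = GPT.generate_causal_linear_bias_mask(T, H).to("cuda", torch.bfloat16)

# Time SDPA with the additive ALiBi mask vs. the plain causal fast path.
t_mask = benchmark.Timer(
    stmt="F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)",
    globals={"F": F, "q": q, "k": k, "v": v, "attn_mask": attn_mask},
)
t_causal = benchmark.Timer(
    stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True)",
    globals={"F": F, "q": q, "k": k, "v": v},
)
print(t_mask.timeit(100))
print(t_causal.timeit(100))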

(I also tried the alibi_slopes argument from the “flash-attention” package, but I couldn’t get it to work with the rest of the code; a sketch of roughly what I attempted is below.)
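For reference, this is roughly the call I was attempting (assuming flash-attn 2.x, where flash_attn_func expects (B, T, n_head, head_dim) tensors, an fp32 alibi_slopes tensor of shape (n_head,), and handles the causal mask via causal=True):

import torch
from flash_attn import flash_attn_func

B, T, n_head, head_dim = 4, 1024, 12, 64
start_slope = 2 ** (-8 / n_head)
alibi_slopes = torch.tensor([start_slope ** i for i in range(1, n_head + 1)],
                            device="cuda", dtype=torch.float32)

# Note the layout difference with SDPA: heads live in dim 2, so no .transpose(1, 2).
q = k = v = torch.randn(B, T, n_head, head_dim, device="cuda", dtype=torch.bfloat16)
out = flash_attn_func(q, k, v, causal=True, alibi_slopes=alibi_slopes)  # (B, T, n_head, head_dim)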

It’s possible FlexAttention will help, if your attention customization is representable here.
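For ALiBi plus a causal mask, a minimal FlexAttention sketch could look like the following (assuming PyTorch >= 2.5; flex_attention is still a prototype API, so exact signatures may shift):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

n_head, T, head_dim = 12, 1024, 64
start_slope = 2 ** (-8 / n_head)
slopes = torch.tensor([start_slope ** i for i in range(1, n_head + 1)], device="cuda")

def alibi_score_mod(score, b, h, q_idx, kv_idx):
    # Same linear bias as generate_causal_linear_bias_mask, computed on the fly
    # per (query, key) pair instead of materializing a (1, n_head, T, T) tensor.
    return score + slopes[h] * (kv_idx - q_idx)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# The block mask replaces the -inf part of the additive mask and lets fully
# masked-out blocks be skipped entirely.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=T, KV_LEN=T, device="cuda")

q = k = v = torch.randn(4, n_head, T, head_dim, device="cuda", dtype=torch.bfloat16)
flex_attention = torch.compile(flex_attention)  # compiling is what makes it competitive
out = flex_attention(q, k, v, score_mod=alibi_score_mod, block_mask=block_mask)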