How to Reduce Memory Usage When Implementing Sparse Attention in a Transformer Model?

I am developing a Transformer model that uses sparse attention to handle long sequences efficiently. In my approach, each query token attends to only a small, pre-specified set of key/value tokens. My implementation uses torch.gather to select the relevant K and V vectors before computing the attention scores against Q.

While the forward pass executes without issues, the backward pass tries to allocate an enormous amount of memory (28610.23 GiB) and fails with a CUDA out-of-memory error. I suspect that the gather operation is not reducing memory consumption the way I intended.
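If I am reading the failure right, the size of the failed allocation lines up exactly with a dense float32 tensor of shape (seq_len, seq_len, d_model), which is the shape of the expanded K (and V) view that torch.gather reads from in the code below, so the backward pass appears to materialize a dense tensor of that shape. A quick sanity check on the numbers:

seq_len, d_model = 100_000, 768
print(round(seq_len * seq_len * d_model * 4 / 1024**3, 2))  # float32 bytes -> 28610.23 GiB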

Below is a simplified version of my code (without batch and head dimensions):

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Transformer(nn.Module):
    def __init__(self, d_model, num_labels, d_ff=2048, dropout=0.1):
        super(Transformer, self).__init__()
        self.m_q = nn.Linear(d_model, d_model)
        self.m_k = nn.Linear(d_model, d_model)
        self.m_v = nn.Linear(d_model, d_model)
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_labels)
    
    def forward(self, x, attn_pattern):
        q = self.m_q(x)
        k = self.m_k(x)
        v = self.m_v(x)
        # Sparse attention
        attn_output = self.sparse_attn(q, k, v, attn_pattern)
        x = self.norm1(x + attn_output)
        # Feed-forward
        ff_output = self.linear2(self.dropout(F.relu(self.linear1(x))))
        x = self.norm2(x + ff_output)
        x = self.classifier(x)
        # Return raw logits; CrossEntropyLoss applies log-softmax internally
        return x
    
    def sparse_attn(self, Q, K, V, attn_pattern):
        '''
        `attn_pattern` is a tensor of shape (q_len, n_attn) containing,
        for each query token, the indices of the key/value tokens it attends to.
        '''
        q_len, d_model = Q.shape
        # Expand attention indices to match the embedding dimension
        attn_pattern = attn_pattern.unsqueeze(-1).expand(-1, -1, d_model)
        # Gather the relevant K and V vectors
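        # (the expanded inputs are (q_len, seq_len, d_model) views of the (seq_len, d_model) K and V)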
        K_selected = torch.gather(K.unsqueeze(0).expand(q_len, -1, -1), 1, attn_pattern)
        V_selected = torch.gather(V.unsqueeze(0).expand(q_len, -1, -1), 1, attn_pattern)
        # Compute attention scores
        scores = torch.einsum('qd,qnd->qn', Q, K_selected) / math.sqrt(d_model)
        attn_weights = F.softmax(scores, dim=-1)
        # Compute attention output
        attn = torch.einsum('qn,qnd->qd', attn_weights, V_selected)
        return attn

if __name__ == '__main__':

    seq_len = 100000 # Long sequence
    d_model = 768
    num_labels = 1000

    input_seq = torch.rand(seq_len, d_model).to('cuda')
    attn_pattern = torch.randint(0, seq_len, (seq_len, 8)).to('cuda') # Each Q token attends to exactly 8 tokens
    labels = torch.randint(0, num_labels, (seq_len,)).to('cuda') # Fake labels

    model = Transformer(d_model, num_labels).to('cuda')
    loss_fn = torch.nn.CrossEntropyLoss()

    output = model(input_seq, attn_pattern)
    loss = loss_fn(output, labels)
    loss.backward() # CUDA out of memory. Tried to allocate 28610.23 GiB.

    print('Finish.')

As the comment above shows, memory consumption spikes dramatically during the backward pass and the script fails with a CUDA out-of-memory error.
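For reference, the peak allocation can be checked with torch.cuda.max_memory_allocated (a minimal sketch reusing the objects from the script above, assuming everything runs on a single CUDA device):

torch.cuda.reset_peak_memory_stats()
output = model(input_seq, attn_pattern)
loss = loss_fn(output, labels)
print(f'Peak after forward:  {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB')
try:
    loss.backward()
except RuntimeError as e:  # the CUDA OOM surfaces as a RuntimeError
    print(e)
print(f'Peak after backward: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB')

The forward pass itself completes; the failure is raised at loss.backward(), as shown in the comment above.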

What is the correct way to implement sparse attention in a Transformer model to minimize memory usage, especially during the backward pass?

I am currently reviewing the source code of BigBird’s attention mechanism but haven’t fully understood it yet. What are the key differences between BigBird’s implementation and mine?