I am developing a Transformer model that uses sparse attention to handle long sequences efficiently: each query token attends to only a small, pre-specified set of key/value tokens. My implementation uses torch.gather to select the relevant K and V vectors before computing the attention scores with Q.
The forward pass runs without issues, but the backward pass tries to allocate an enormous amount of memory (28610.23 GiB) and fails with a CUDA out-of-memory error. I suspect the gather-based selection is not actually reducing memory consumption the way I intended.
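The number looks suspicious to me: assuming float32 (4 bytes per element), it matches a dense tensor of shape (q_len, seq_len, d_model) for my sizes almost exactly, which makes me think something in the backward pass materializes a full q_len × seq_len intermediate instead of the sparse one. A quick back-of-the-envelope check:

q_len, seq_len, d_model = 100_000, 100_000, 768
print(q_len * seq_len * d_model * 4 / 2**30)  # ≈ 28610.23 GiB, assuming float32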
Below is a simplified version of my code (without batch and head dimensions):
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class Transformer(nn.Module):
    def __init__(self, d_model, num_labels, d_ff=2048, dropout=0.1):
        super().__init__()
        self.m_q = nn.Linear(d_model, d_model)
        self.m_k = nn.Linear(d_model, d_model)
        self.m_v = nn.Linear(d_model, d_model)
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_labels)
    def forward(self, x, attn_pattern):
        q = self.m_q(x)
        k = self.m_k(x)
        v = self.m_v(x)
        # Sparse attention
        attn_output = self.sparse_attn(q, k, v, attn_pattern)
        x = self.norm1(x + attn_output)
        # Feed-forward
        ff_output = self.linear2(self.dropout(F.relu(self.linear1(x))))
        x = self.norm2(x + ff_output)
        # Return raw logits; CrossEntropyLoss applies log-softmax internally
        x = self.classifier(x)
        return x
    def sparse_attn(self, Q, K, V, attn_pattern):
        '''
        `attn_pattern` is a tensor of shape (q_len, n_attn) containing the
        indices of the keys/values each query token attends to.
        '''
        q_len, d_model = Q.shape
        # Expand attention indices to match the embedding dimension: (q_len, n_attn, d_model)
        attn_pattern = attn_pattern.unsqueeze(-1).expand(-1, -1, d_model)
        # Gather the relevant K and V vectors: (q_len, n_attn, d_model)
        K_selected = torch.gather(K.unsqueeze(0).expand(q_len, -1, -1), 1, attn_pattern)
        V_selected = torch.gather(V.unsqueeze(0).expand(q_len, -1, -1), 1, attn_pattern)
        # Compute attention scores: (q_len, n_attn)
        scores = torch.einsum('qd,qnd->qn', Q, K_selected) / math.sqrt(d_model)
        attn_weights = F.softmax(scores, dim=-1)
        # Weighted sum of the selected values: (q_len, d_model)
        attn = torch.einsum('qn,qnd->qd', attn_weights, V_selected)
        return attn
if __name__ == '__main__':
    seq_len = 100000  # Long sequence
    d_model = 768
    num_labels = 1000

    input_seq = torch.rand(seq_len, d_model).to('cuda')
    attn_pattern = torch.randint(0, seq_len, (seq_len, 8)).to('cuda')  # Each query attends to exactly 8 tokens
    labels = torch.randint(0, num_labels, (seq_len,)).to('cuda')  # Fake labels
    model = Transformer(d_model, num_labels).to('cuda')
    loss_fn = torch.nn.CrossEntropyLoss()

    output = model(input_seq, attn_pattern)
    loss = loss_fn(output, labels)
    loss.backward()  # CUDA out of memory. Tried to allocate 28610.23 GiB.
    print('Finish.')
The spike happens only during the backward pass; the forward pass alone completes without problems.
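To isolate the issue, here is a minimal sketch of the same gather-over-an-expanded-view pattern that my sparse_attn uses, with made-up sizes that are much smaller than my real ones. If my reading is correct, the peak memory during backward should scale with q_len * seq_len rather than with q_len * n_attn:

import torch

# Made-up small sizes, just to reproduce the access pattern
q_len, seq_len, n_attn, d_model = 1000, 1000, 8, 64
device = 'cuda' if torch.cuda.is_available() else 'cpu'

K = torch.randn(seq_len, d_model, device=device, requires_grad=True)
idx = torch.randint(0, seq_len, (q_len, n_attn), device=device)
idx = idx.unsqueeze(-1).expand(-1, -1, d_model)

# Same pattern as sparse_attn: expand K to (q_len, seq_len, d_model), then gather
K_selected = torch.gather(K.unsqueeze(0).expand(q_len, -1, -1), 1, idx)
K_selected.sum().backward()

if device == 'cuda':
    # K_selected itself is only ~2 MiB, but the peak appears to be dominated by a dense
    # (q_len, seq_len, d_model) gradient buffer (~244 MiB here, if I understand correctly)
    print(torch.cuda.max_memory_allocated() / 2**20, 'MiB')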
What is the correct way to implement this kind of sparse attention so that memory usage, especially during the backward pass, scales with the number of attended tokens rather than with the full sequence length?
I am currently reviewing the source code of BigBird’s attention mechanism but haven’t fully understood it yet. What are the key differences between BigBird’s implementation and mine?