Hello everyone, I want to get the query @ key attention weights out of F.scaled_dot_product_attention. I wrote the chunked implementation below, but I still get an OOM error, even though calling F.scaled_dot_product_attention directly runs fine without OOM. Please help…
Alternatively, how could I change the source code of F.scaled_dot_product_attention so that it returns these weights?
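On small inputs I can already compute the quantity I'm after by hand, so the question is really only about memory. This is just a sanity-check sketch with toy shapes (assuming the math backend; not my real sizes), showing that what I want is the pre-softmax query @ key product:

import math
import torch
import torch.nn.functional as F

# Toy shapes: [batch, heads, seq_len, head_dim]
q = torch.randn(1, 2, 16, 8)
k = torch.randn(1, 2, 16, 8)
v = torch.randn(1, 2, 16, 8)

# The pre-softmax scores I want to extract.
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))

# Finishing the computation reproduces SDPA's output on small inputs.
out = torch.softmax(scores, dim=-1) @ v
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out, ref, atol=1e-5))  # should print True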
import math
import torch

def chunk_dot_product(query, key, num_chunks=2000):
    # query, key shape: [batch_size, num_heads, seq_len, head_dim]
    batch_size, num_heads, seq_len, head_dim = query.shape
    # Round up so the last chunk covers any remainder (and chunk_size is never 0).
    chunk_size = math.ceil(seq_len / num_chunks)
    # Collect one row block of the score matrix per query chunk.
    attn_chunks = []
    for i in range(0, seq_len, chunk_size):
        chunk_weights = []
        # Current query chunk: [batch_size, num_heads, chunk_size, head_dim]
        q_chunk = query[:, :, i:i + chunk_size]
        # Split the key along the sequence dimension as well.
        for j in range(0, seq_len, chunk_size):
            k_chunk = key[:, :, j:j + chunk_size]
            # Partial attention weights: [batch_size, num_heads, chunk_size, chunk_size]
            chunk_attn = torch.matmul(q_chunk, k_chunk.transpose(-1, -2))
            chunk_weights.append(chunk_attn)
            # Free memory eagerly.
            del k_chunk
            torch.cuda.empty_cache()
        # Concatenate along the key dimension: [batch_size, num_heads, chunk_size, seq_len]
        row_weights = torch.cat(chunk_weights, dim=-1)
        attn_chunks.append(row_weights)
        # Clear intermediate results.
        del chunk_weights
        del q_chunk
        torch.cuda.empty_cache()
    # Stitch all row blocks together: [batch_size, num_heads, seq_len, seq_len]
    attn_weight = torch.cat(attn_chunks, dim=2)
    return attn_weight
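For context, this is roughly how I call the helper above. The shapes here are made-up placeholders (not my real model sizes), and a GPU is used only if available:

import torch

# Hypothetical shapes: [batch_size, num_heads, seq_len, head_dim]
device = "cuda" if torch.cuda.is_available() else "cpu"
query = torch.randn(1, 8, 1024, 64, device=device)
key = torch.randn(1, 8, 1024, 64, device=device)

# Full [batch_size, num_heads, seq_len, seq_len] score matrix, assembled chunk by chunk.
attn_weight = chunk_dot_product(query, key, num_chunks=8)
print(attn_weight.shape)  # torch.Size([1, 8, 1024, 1024])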
# Reference implementation equivalent to F.scaled_dot_product_attention
# (adapted from the PyTorch docs), modified to return the raw query @ key weights:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    # Build the full score matrix chunk by chunk instead of one big matmul.
    attn_weight = chunk_dot_product(query, key) * scale_factor
    attn_weight += attn_bias
    # Softmax/dropout would come after the (LoRA) masking; left commented out
    # because I only want the raw weights here.
    # attn_weight = torch.softmax(attn_weight, dim=-1)
    # attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight
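My current guess at why the chunking does not help: the torch.cat at the end still has to hold the entire [batch, heads, seq_len, seq_len] matrix in memory at once, while the fused SDPA kernels never materialize it, which would explain why F.scaled_dot_product_attention alone runs fine. A back-of-the-envelope estimate (example numbers only, not my actual config):

# Rough size of the materialized score matrix alone.
def attn_matrix_bytes(batch_size, num_heads, seq_len, bytes_per_elem=2):  # 2 bytes = fp16
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem

# e.g. batch=1, 32 heads, seq_len=32768 in fp16:
print(attn_matrix_bytes(1, 32, 32768) / 1024**3, "GiB")  # 64.0 GiB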