How to get query @ key from F.scaled_dot_product_attention

Hello everyone, I want to know how to get the query @ key matrix out of F.scaled_dot_product_attention. I tried the chunked code below, but I still get OOM, even though calling F.scaled_dot_product_attention itself runs fine without OOM. Please help…
Alternatively, how would I change the source code of F.scaled_dot_product_attention to return it?
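
To make it concrete, here is roughly what I mean (a minimal sketch; the shapes, dtype, and device are just placeholders for my real setup):

import math
import torch
import torch.nn.functional as F

# Placeholder shapes, just for illustration
B, H, S, D = 1, 8, 16384, 64
query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

# This works without OOM: the fused kernels avoid materializing the full
# [B, H, S, S] score matrix
out = F.scaled_dot_product_attention(query, key, value)

# This is what I actually want, but it allocates the full [B, H, S, S]
# score matrix explicitly, which is presumably where my OOM comes from
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(D)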

import math
import torch


def chunk_dot_product(query, key, num_chunks=2000):
    # query, key shape: [batch_size, num_heads, seq_len, head_dim]
    batch_size, num_heads, seq_len, head_dim = query.shape
    assert seq_len % num_chunks == 0, "seq_len must be divisible by num_chunks"
    chunk_size = seq_len // num_chunks

    # Initialize the list of output tensors
    attn_chunks = []

    for i in range(num_chunks):
        chunk_weights = []
        # Take the current query chunk: [batch_size, num_heads, chunk_size, head_dim]
        q_chunk = query[:, :, i*chunk_size:(i+1)*chunk_size]

        # Process the key in chunks as well
        for j in range(num_chunks):
            k_chunk = key[:, :, j*chunk_size:(j+1)*chunk_size]

            # Compute the partial attention weights
            # [batch_size, num_heads, chunk_size, chunk_size]
            chunk_attn = torch.matmul(q_chunk, k_chunk.transpose(-1, -2))
            chunk_weights.append(chunk_attn)

            # Free memory along the way
            if j < num_chunks - 1:  # the last chunk does not need cleanup
                del k_chunk
                torch.cuda.empty_cache()

        # Concatenate along the key sequence dimension: [batch_size, num_heads, chunk_size, seq_len]
        row_weights = torch.cat(chunk_weights, dim=-1)
        attn_chunks.append(row_weights)

        # Clear intermediate results
        del chunk_weights
        del q_chunk
        torch.cuda.empty_cache()

    # Finally combine all chunks: [batch_size, num_heads, seq_len, seq_len]
    attn_weight = torch.cat(attn_chunks, dim=2)
    return attn_weight
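
A rough example of how I call it (the shapes and num_chunks here are made up):

# Placeholder shapes; assumes seq_len is divisible by num_chunks
q = torch.randn(1, 8, 16000, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 16000, 64, device="cuda", dtype=torch.float16)
attn = chunk_dot_product(q, k, num_chunks=8)   # -> [1, 8, 16000, 16000]

# Even though each chunk is small, the returned tensor alone needs
# 1 * 8 * 16000 * 16000 * 2 bytes ≈ 3.8 GiB in float16, plus the per-row
# chunks that stay referenced in the lists until the final torch.cat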

# The reference implementation of F.scaled_dot_product_attention from the PyTorch docs, modified to use chunk_dot_product:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = chunk_dot_product(query, key) * scale_factor
    attn_weight += attn_bias
    # after the lora masking, do the softmax (left commented out; I only need the raw weights)
    # attn_weight = torch.softmax(attn_weight, dim=-1)
    # attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    
    return attn_weight