Hello everyone,
I’ve encountered an unusual issue involving the print function and CUDA tensors. When I don’t print a particular tensor on CUDA, I get an all-zero result; when I do print it, I see non-zero values. It’s worth noting that I’ve already added explicit torch.cuda.synchronize() calls.
Considering that the print function shouldn’t modify tensor values, I find this behavior perplexing.
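As far as I understand (this is an assumption on my part, not something I have verified against the PyTorch internals), printing a CUDA tensor only copies it to the host and formats the host copy, along these lines:

import torch

# Rough mental model of print(q) for a CUDA tensor (an assumption, not the
# actual PyTorch implementation): copy to host, then format the host copy.
q = torch.randn(2, 3, device="cuda")  # stand-in tensor, just for illustration
q_host = q.detach().cpu()             # device-to-host copy; blocks until GPU work on q finishes
print(q_host)                         # formatting happens on the CPU copy; q itself is untouched

If that mental model is right, printing should at most add an extra synchronization, which I am already doing explicitly.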
Would anyone have insights or suggestions on this matter?
For reference, I’ve provided a minimal code example below, in which commenting out or keeping the print(q) line leads to different results. Please keep in mind that while the function itself may seem arbitrary, my main concern is the inconsistency of the results:
(Tested with python==3.10, torch==2.0.0, on an A100 and an RTX 3090.)
# Install the dependency via:
# pip install flash-attn --no-build-isolation
from collections import namedtuple

import torch
from flash_attn import flash_attn_varlen_kvpacked_func

AttentionKernelConfig = namedtuple(
    'config',
    ['batch_size', 'sequence_length', 'num_heads', 'head_size', 'dtype', 'causal']
)


def prepare_shared_args(kernel_config):
    # One random tensor holding q, k, and v stacked along dim 0.
    qkv = torch.randn(
        3,
        kernel_config.batch_size,
        kernel_config.sequence_length,
        kernel_config.num_heads * kernel_config.head_size,
        dtype=kernel_config.dtype,
        device="cuda",
    )
    q = qkv[0]
    k = qkv[1]
    v = qkv[2]
    return q, k, v


def prepare_unique_args(kernel_config):
    q, k, v = prepare_shared_args(kernel_config)
    bs = kernel_config.batch_size
    seqlen = kernel_config.sequence_length
    num_head = kernel_config.num_heads
    head_dim = kernel_config.head_size
    # Keep only the first query token of each sequence: (bs, num_head, head_dim).
    q = q.view(bs, -1, num_head, head_dim)[:, :1].reshape(bs, num_head, head_dim)
    # Pack k and v into a single (bs * seqlen, 2, num_head, head_dim) tensor.
    kv = torch.cat([k, v], dim=0).view(bs, seqlen, 2, num_head, head_dim).reshape(bs * seqlen, 2, num_head, head_dim)
    # Cumulative and maximum sequence lengths for q and for kv.
    cu_seqlens = torch.ones(bs + 1, dtype=torch.int32, device=q.device)
    max_seqlen = 1
    cu_seqlens_k = torch.ones(bs + 1, dtype=torch.int32, device=q.device) * seqlen
    max_seqlen_k = seqlen
    # print(q)  # Uncommenting this line changes the result below.
    return q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k


def prepare_unique_kwargs(kernel_config):
    return {'causal': kernel_config.causal}


attention_config = AttentionKernelConfig(4, 1024, 32, 128, torch.float16, False)
args = prepare_unique_args(attention_config)
kwargs = prepare_unique_kwargs(attention_config)
torch.cuda.synchronize()
result = flash_attn_varlen_kvpacked_func(*args, **kwargs)
torch.cuda.synchronize()
print(result)
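In case it helps with diagnosis, here is a variant I am considering (a sketch only, reusing attention_config and the helper functions from the example above): it reads q without printing it, to check whether merely touching the tensor, rather than print specifically, changes the kernel’s output.

# Sketch: touch q right after preparing the arguments instead of printing it.
args = prepare_unique_args(attention_config)
q = args[0]
_ = q.clone()  # on-device copy of q; does not synchronize by itself
_ = q.cpu()    # device-to-host copy; this does block, much like printing would
torch.cuda.synchronize()
result = flash_attn_varlen_kvpacked_func(*args, **prepare_unique_kwargs(attention_config))
torch.cuda.synchronize()
print(result)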
Thank you in advance for your assistance!