A single `print` of a tensor can affect the result

Hello everyone,

I’ve encountered an unusual issue involving the print function and CUDA tensors. When I don’t print a particular CUDA tensor, I get an all-zero result; when I do print it, I see non-zero values. It’s worth noting that I’ve already added explicit torch.cuda.synchronize() calls.

Considering that the print function shouldn’t modify tensor values, I find this behavior perplexing.

Would anyone have insights or suggestions on this matter?

For reference, I’ve provided a minimal code example below, in which commenting out print(q) or leaving it in leads to different results. Please keep in mind that while the function may seem arbitrary, my main concern is the inconsistent results I’m observing:
(Tested with python==3.10 and torch==2.0.0, on both an A100 and an RTX 3090.)

# install the dep via 
# pip install flash-attn --no-build-isolation
from collections import namedtuple
import torch
from flash_attn import flash_attn_varlen_kvpacked_func

AttentionKernelConfig = namedtuple('config', ['batch_size', 'sequence_length', 'num_heads', 'head_size', 'dtype', 'causal'])

def prepare_shared_args(kernel_config):
    qkv = torch.randn(3, kernel_config.batch_size, kernel_config.sequence_length, kernel_config.num_heads * kernel_config.head_size, dtype=kernel_config.dtype, device="cuda")
    q = qkv[0]
    k = qkv[1]
    v = qkv[2]
    return q, k, v

def prepare_unique_args(kernel_config):
    q, k, v = prepare_shared_args(kernel_config)
    bs = kernel_config.batch_size
    seqlen = kernel_config.sequence_length
    num_head = kernel_config.num_heads
    head_dim = kernel_config.head_size
    q = q.view(bs, -1, num_head, head_dim)[:, :1].reshape(bs, num_head, head_dim)
    kv = torch.cat([k, v], dim=0).view(bs, seqlen, 2, num_head, head_dim).reshape(bs * seqlen, 2, num_head, head_dim)
    cu_seqlens = torch.ones(bs + 1, dtype=torch.int32, device=q.device)
    max_seqlen = 1
    cu_seqlens_k = torch.ones(bs + 1, dtype=torch.int32, device=q.device) * seqlen
    max_seqlen_k = seqlen
    # print(q)
    return q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k

def prepare_unique_kwargs(kernel_config):
    return {
        'causal': kernel_config.causal
    }

attention_config = AttentionKernelConfig(4, 1024, 32, 128, torch.float16, False)
args = prepare_unique_args(attention_config)
kwargs = prepare_unique_kwargs(attention_config)
torch.cuda.synchronize()
result = flash_attn_varlen_kvpacked_func(*args, **kwargs)
torch.cuda.synchronize()
print(result)
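
For clarity, by an “all-zero result” I mean that a quick check along these lines (it only summarizes the output above, nothing beyond the script) reports no non-zero entries:

print(torch.count_nonzero(result).item(), result.abs().max().item())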

Thank you in advance for your assistance!

The described issue sounds like a race condition caused by missing syncs. Are you using custom CUDA extensions? If so, do you still see this issue in any “pure” PyTorch code?
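
For example, a rough sketch of what such a pure-PyTorch cross-check could look like (reusing the q/kv layout from your script and torch.nn.functional.scaled_dot_product_attention as a reference; the helper below is only illustrative and not an exact equivalent of the varlen kernel):

import torch
import torch.nn.functional as F

def pure_pytorch_reference(q, kv, batch_size, sequence_length, num_heads, head_size):
    # q:  (batch_size, num_heads, head_size) -- one query token per batch element
    # kv: (batch_size * sequence_length, 2, num_heads, head_size)
    k, v = kv.unbind(dim=1)
    k = k.view(batch_size, sequence_length, num_heads, head_size).transpose(1, 2)
    v = v.view(batch_size, sequence_length, num_heads, head_size).transpose(1, 2)
    q = q.view(batch_size, 1, num_heads, head_size).transpose(1, 2)
    # plain dense attention, ignoring the varlen bookkeeping entirely
    out = F.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2).reshape(batch_size, num_heads, head_size)

If this reference also flips between all-zero and non-zero depending on an unrelated print, the problem is not specific to the custom kernel.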

Thanks for the reply!
Yes, the script uses custom CUDA extensions, and I have not seen this behavior in pure PyTorch code.
But I do not understand why a “race” condition would be resolved by a single print. The data on the device should be the same.
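
One way to narrow down what print(q) actually changes (just a sketch based on the script above; each line is an independent experiment) would be to replace it with operations that reproduce parts of its side effects and see whether any of them also flips the result:

# inside prepare_unique_args, instead of print(q), try one of these at a time:
_ = q.cpu()                # device-to-host copy (printing a CUDA tensor copies it to the host too)
torch.cuda.synchronize()   # an extra full device synchronization
q = q.contiguous()         # materializes the strided view of q into fresh memory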