When performing a prefix-sum operation, there appears to be an abnormal time overhead: subsequent operations on the prefix-sum result incur a very large additional time cost. I am unsure what causes this and how to resolve it.
code:
import time

import torch

# dstList is a large CUDA index tensor produced by the "loading data" step shown below.
binAns = torch.bincount(dstList)            # count occurrences of each index

cumTime = time.time()
ptrcum = torch.cumsum(binAns, dim=0)        # prefix sum over the counts
print("dstList:", dstList.shape)
print("binAns:", binAns.shape)
print(f"using time : {time.time() - cumTime:.4f}s...")

cumTime = time.time()
zeroblock = torch.zeros(1, device=binAns.device)
print(f"zeroblock using time : {time.time() - cumTime:.4f}s...")

cumTime = time.time()
inptr = torch.cat([zeroblock, ptrcum]).to(torch.int32)  # prepend 0 to form the pointer array
print(f"cat using time : {time.time() - cumTime:.4f}s...")
output:
loading data time : 0.4238s
dstList: torch.Size([148151720])
binAns: torch.Size([9805926])
using time : 0.0002s...
zeroblock using time : 0.0000s...
cat using time : 0.0001s...
If I include a print operation:
binAns = torch.bincount(dstList)

cumTime = time.time()
ptrcum = torch.cumsum(binAns, dim=0)
print("dstList:", dstList.shape)
print("binAns:", binAns.shape)
print(f"using time : {time.time() - cumTime:.4f}s...")

cumTime = time.time()
zeroblock = torch.zeros(1, device=binAns.device)
print(f"zeroblock using time : {time.time() - cumTime:.4f}s...")

cumTime = time.time()
inptr = torch.cat([zeroblock, ptrcum]).to(torch.int32)
print(inptr)  # additional print operation
print(f"cat using time : {time.time() - cumTime:.4f}s...")
output:
loading data time : 0.4285s
dstList: torch.Size([148151720])
binAns: torch.Size([9805926])
using time : 0.0002s...
zeroblock using time : 0.0000s...
tensor([        0,         2,         4,  ..., 148151712, 148151712,
        148151712], device='cuda:0', dtype=torch.int32)
cat using time : 140.7660s...
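
My suspicion is that CUDA kernels are launched asynchronously, so `time.time()` after each call only measures the kernel launch, while `print(inptr)` forces a device-to-host copy that has to wait for every kernel queued so far (including the large `bincount` and `cumsum`), and that wait then gets charged to the `cat` timer. If that is the case, I assume each step should be timed with an explicit device synchronization — a minimal sketch, where the random `dstList` is a hypothetical, much smaller stand-in for my real data:

import time

import torch

def timed(label, fn):
    # Synchronize before starting the clock so previously queued kernels are
    # drained, and after the call so the measurement covers the GPU work
    # itself rather than just the asynchronous launch.
    torch.cuda.synchronize()
    t0 = time.time()
    out = fn()
    torch.cuda.synchronize()
    print(f"{label} using time : {time.time() - t0:.4f}s...")
    return out

# Hypothetical stand-in for my real dstList (~148M elements); smaller here for illustration.
dstList = torch.randint(0, 1_000_000, (10_000_000,), device="cuda")

binAns = timed("bincount", lambda: torch.bincount(dstList))
ptrcum = timed("cumsum", lambda: torch.cumsum(binAns, dim=0))
zeroblock = torch.zeros(1, device=binAns.device)
inptr = timed("cat", lambda: torch.cat([zeroblock, ptrcum]).to(torch.int32))

With this pattern I would expect the ~140s currently attributed to `cat` to move to whichever step actually does the heavy work; pairs of `torch.cuda.Event(enable_timing=True)` would be an alternative way to measure the GPU time directly. Is asynchronous execution really what is happening here, and is there a way to reduce the cost, or is it unavoidable work that the print merely exposes?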