import torch
from torch.utils.cpp_extension import load
wait_test = load(name='wait_flag', sources=['cuda_flag_wait.cpp', 'cuda_flag_wait.cu'])
a_s = [torch.tensor([False]*128,dtype=torch.bool).cuda() for _ in range(20)]
w1 = torch.randn(4096,14336,dtype=torch.bfloat16).cuda()
x1 = torch.randn(1,4096,dtype=torch.bfloat16).cuda()
g = torch.cuda.CUDAGraph()
j = 0

# warmup
for i in range(1):
    torch.matmul(x1, w1)
    wait_test.wait_flag(2000000, a_s[j])
    torch.matmul(x1, w1)
# torch.cuda.synchronize()
j = 0
with torch.cuda.graph(g):
    torch.matmul(x1, w1)
    wait_test.wait_flag(2000000, a_s[j])
    torch.matmul(x1, w1)
g.replay()
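# Usage sketch (assumption, not in the original): the replayed graph blocks
# inside the wait kernel until the flag turns true or the 2,000,000-clock
# timeout expires. Since the replay stalls the current stream, the flag has
# to be raised from a different stream to release the wait early.
release = torch.cuda.Stream()
with torch.cuda.stream(release):
    a_s[j].fill_(True)
torch.cuda.synchronize()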
########### cuda_flag_wait.cpp
#include <torch/extension.h>
#include <torch/torch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
static at::cuda::CUDAStream captureStream = at::cuda::getStreamFromPool();
static at::cuda::CUDAEvent cuda_events;
void wait_test(int clock_count, bool* flag, cudaStream_t stream);
void set_flag(bool* flag);
void wait_flag(int clock_count, torch::Tensor device_modules)
{
    // Launch on the current stream so the kernel launch is captured when the
    // call happens inside torch.cuda.graph (the legacy default stream cannot
    // be captured).
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    wait_test(clock_count, device_modules.data_ptr<bool>(), stream);
    // cuda_events.record(captureStream);
}
// Binding
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("wait_flag", &wait_flag, "wait");
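    // Sketch (assumption, not in the original binding): set_flag is declared
    // above but never exposed; a matching binding could look like this:
    // m.def("set_flag",
    //       [](torch::Tensor flag) { set_flag(flag.data_ptr<bool>()); },
    //       "raise the flag the wait kernel is spinning on");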
}
######### cuda_flag_wait.cu
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <cub/cub.cuh>
#include <thrust/device_vector.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <string.h>
// Spin-wait kernel: busy-waits until one of the first 128 flag bytes reads
// true or until clock_count device clocks have elapsed. The flag pointer is
// volatile so each load is re-issued instead of being cached in a register.
__global__ void clock_block(clock_t clock_count, volatile bool* flag)
{
    clock_t start_clock = clock();
    clock_t clock_offset = 0;
    int count = 0;
    bool state_v = flag[0];
    while ((clock_offset < clock_count) && (!state_v))
    {
        clock_offset = clock() - start_clock;
        state_v = flag[count];
        count = (count + 1) % 128;  // cycle through the 128 flag slots
    }
}
// Host-side launcher: a single spinning thread is sufficient.
void wait_test(int clock_count, bool* flag, cudaStream_t stream)
{
    dim3 block(1, 1);
    dim3 grid(1, 1);
    // Launch on the caller's stream so the kernel is captured when the call
    // happens inside torch.cuda.graph, rather than on the default stream.
    clock_block<<<grid, block, 0, stream>>>(clock_count, flag);
}
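
// Sketch (assumption, not defined in the original): cuda_flag_wait.cpp
// declares set_flag but no definition appears anywhere. A minimal matching
// definition would raise slot 0 so the spinning kernel observes it and
// exits; it should be launched from a stream other than the one blocked in
// clock_block.
__global__ void set_flag_kernel(volatile bool* flag)
{
    flag[0] = true;
}

void set_flag(bool* flag)
{
    set_flag_kernel<<<1, 1>>>(flag);
}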