What I suspect is that the performance degradation comes from PCIe, since all 4 GPUs read and write the data through the zero-copy technique.
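A quick way to sanity-check the PCIe hypothesis is to measure plain pinned host-to-device copy bandwidth per GPU and compare it against what the zero-copy kernels achieve; a minimal sketch (not part of my benchmark, the sizes and iteration counts are arbitrary):
import time
import torch as th

def h2d_bandwidth(device_id, size_mb=256, iters=20):
    # Rough pinned host-to-device copy bandwidth for one GPU, in GB/s.
    n = size_mb * 1024 * 1024 // 4                      # number of float32 elements
    src = th.empty(n, dtype=th.float32).pin_memory()    # pinned source buffer on the host
    dst = th.empty(n, dtype=th.float32, device='cuda:{}'.format(device_id))
    th.cuda.synchronize(device_id)
    start = time.time()
    for _ in range(iters):
        dst.copy_(src, non_blocking=True)
    th.cuda.synchronize(device_id)
    return size_mb * iters / 1024 / (time.time() - start)

if __name__ == '__main__':
    for dev in range(th.cuda.device_count()):
        print('GPU {}: {:.1f} GB/s'.format(dev, h2d_bandwidth(dev)))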
My Python code is as follows:
import torch.multiprocessing as mp
import torch as th
import time
import sys
from torch.utils.cpp_extension import load
import numpy as np

zerocopy_cpp = load(name='testcase', sources=['zp_test.cu'],
                    extra_cflags=['-I/usr/local/cuda/include'],
                    extra_cuda_cflags=['-I/usr/local/cuda/include'],
                    extra_ldflags=['-lcuda', '-ldl'])

class ZeroCopy(th.autograd.Function):
    @staticmethod
    def forward(ctx, emb, indices, device):
        # gather rows of the pinned CPU embedding onto the GPU via zero-copy
        output = zerocopy_cpp.zero_copy_call(emb, indices, device)
        return output

    @staticmethod
    def backward(ctx):
        pass

class ZeroWrite(th.autograd.Function):
    @staticmethod
    def forward(ctx, emb, res, indices, device):
        # scatter-add the GPU result back into the pinned CPU embedding
        zerocopy_cpp.zero_write(emb, res, indices, device)

    @staticmethod
    def backward(ctx):
        pass

class Pin_Mem(th.autograd.Function):
    @staticmethod
    def forward(ctx, emb):
        # page-lock (pin) the CPU tensor so the GPUs can access it directly
        zerocopy_cpp.pin_mem(emb)
        return emb

    @staticmethod
    def backward(ctx):
        pass

zero_copy = ZeroCopy.apply
zero_write = ZeroWrite.apply
pin_mem = Pin_Mem.apply

def train_mp(emb, rank):
    pin_mem(emb)
    copy_time = 0
    write_time = 0
    for i in range(1000):
        indices = th.randint(0, 3000000, (10000,))
        start = time.time()
        data = zero_copy(emb, indices, rank)
        copy_time += time.time() - start
        grad = data * 0.1
        start = time.time()
        zero_write(emb, grad, indices, rank)
        write_time += time.time() - start
    print('copy time on {} is {}'.format(rank, copy_time))
    print('write time on {} is {}'.format(rank, write_time))

def main():
    mp.set_start_method('forkserver')
    num_gpus = 4
    data = th.rand((3000000, 100))
    pin_mem(data)
    procs = []
    for i in range(num_gpus):
        proc = mp.Process(target=train_mp, args=(data, i))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()

if __name__ == '__main__':
    main()
and my CUDA code is as follows:
#include <torch/extension.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cstdint>
#include <iostream>
#include <bitset>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <errno.h>
#include <error.h>
#include <stdlib.h>
#include <sys/time.h>
using namespace std;
typedef unsigned __int128 uint128_t;
#define abort(ret, errno, ...) error_at_line(ret, errno, __FILE__, __LINE__, \
__VA_ARGS__)
#define CEIL(a, b) (((a)+(b)-1)/(b))
#define CHECK(call) \
{ \
const cudaError_t error = call; \
if (error != cudaSuccess) \
{ \
fprintf(stderr, "Error: %s:%d, ", __FILE__, __LINE__); \
fprintf(stderr, "code: %d, reason: %s\n", error, \
cudaGetErrorString(error)); \
} \
}
// Gather kernel: each block handles blockDim.y rows; threadIdx.x strides over the embedding dim.
__global__ void index_kernel(float *res, long *indices, float *src, int upper_bound, int dim)
{
    const int idx = blockIdx.x * blockDim.y + threadIdx.y;
    if (idx < upper_bound) {
        for (int i = threadIdx.x; i < dim; i += blockDim.x) {
            res[idx * dim + i] = src[indices[idx] * dim + i];
        }
    }
}
torch::Tensor zero_copyH2D(torch::Tensor emb, torch::Tensor indices, int dev_id) {
    cudaSetDevice(dev_id);
    dim3 block(32, 32);
    dim3 grids = (CEIL(indices.size(0), block.y));
    dim3 grids_vec = (CEIL(indices.size(0), block.y * block.x));
    torch::Device dev = indices.device();
    long *idx;
    CHECK(cudaMalloc(&idx, sizeof(long) * indices.size(0)));
    CHECK(cudaMemcpy(idx, indices.data_ptr<long>(), sizeof(long) * indices.size(0), cudaMemcpyHostToDevice));
    torch::Tensor res = torch::empty({indices.size(0), emb.size(1)},
                                     torch::TensorOptions(torch::kFloat32).device(torch::kCUDA, dev_id));
    // emb points at pinned host memory, so the kernel reads it over PCIe (zero-copy).
    index_kernel<<<grids, block, 0>>>(res.data_ptr<float>(), idx, emb.data_ptr<float>(), indices.size(0), emb.size(1));
    cudaDeviceSynchronize();
    CHECK(cudaFree(idx));
    return res;
}
// Scatter-add kernel: writes the per-index gradients back into the host-resident embedding.
__global__ void write_kernel(float *emb, long *indices, float *res, int upper_bound, int dim)
{
    const int idx = blockIdx.x * blockDim.y + threadIdx.y;
    if (idx < upper_bound) {
        for (int i = threadIdx.x; i < dim; i += blockDim.x) {
            emb[indices[idx] * dim + i] += res[idx * dim + i];
        }
    }
}
void zero_writeD2H(torch::Tensor emb, torch::Tensor res, torch::Tensor indices, int dev_id) {
    cudaSetDevice(dev_id);
    dim3 block(32, 32);
    dim3 grids = (CEIL(indices.size(0), block.y));
    torch::Device dev = indices.device();
    long *idx;
    CHECK(cudaMalloc(&idx, sizeof(long) * indices.size(0)));
    CHECK(cudaMemcpy(idx, indices.data_ptr<long>(), sizeof(long) * indices.size(0), cudaMemcpyHostToDevice));
    // emb points at pinned host memory, so the kernel writes it over PCIe (zero-copy).
    write_kernel<<<grids, block, 0>>>(emb.data_ptr<float>(), idx, res.data_ptr<float>(), indices.size(0), emb.size(1));
    cudaDeviceSynchronize();
    CHECK(cudaFree(idx));
}
void pin_mem(torch::Tensor emb) {
    // cudaHostRegisterMapped is the flag that pairs with cudaHostRegister
    // (cudaHostAllocMapped is the corresponding flag for cudaHostAlloc).
    CHECK(cudaHostRegister(emb.data_ptr<float>(), sizeof(float) * emb.size(0) * emb.size(1),
                           cudaHostRegisterPortable | cudaHostRegisterMapped));
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("zero_copy_call", &zero_copyH2D, "zero-copy read from CPU to GPU");
    m.def("zero_write", &zero_writeD2H, "zero-copy write from GPU to CPU");
    m.def("pin_mem", &pin_mem, "pin memory on CPU");
}
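For reference, the kernels above pass the registered host pointer (emb.data_ptr<float>()) straight to the device, which relies on unified virtual addressing. A more explicit way to obtain the device-visible alias of pinned memory is cudaHostGetDevicePointer; a minimal sketch, where mapped_device_ptr is just a hypothetical helper name:
#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical helper: returns the device-visible alias of a host buffer that has
// been registered with cudaHostRegisterMapped. On systems with unified virtual
// addressing this is typically the same address as the host pointer itself.
static float *mapped_device_ptr(float *host_ptr) {
    float *dev_ptr = nullptr;
    cudaError_t err = cudaHostGetDevicePointer((void **)&dev_ptr, (void *)host_ptr, 0);
    if (err != cudaSuccess) {
        fprintf(stderr, "cudaHostGetDevicePointer failed: %s\n", cudaGetErrorString(err));
        return nullptr;
    }
    return dev_ptr;
}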