Why doesn't this CUDA kernel execute?

I’m new to CUDA programming, but I need to do some PyTorch CUDA programming now.
So I looked at the PyTorch CUDA extension tutorial and gave it a try. First I wanted to try the add example, so I wrote:

// add_cuda.h
#ifndef _ADD_CUDA
#define _ADD_CUDA

#include <torch/extension.h>
void add(torch::Tensor a, torch::Tensor b, torch::Tensor c);
// Launcher defined in add_cuda.cu
void add_cu(torch::Tensor a, torch::Tensor b, torch::Tensor c);
#endif
// add_wrapper.cpp
#include "add_cuda.h"

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

void add(torch::Tensor a, torch::Tensor b, torch::Tensor c){
    CHECK_INPUT(a);
    CHECK_INPUT(b);
    CHECK_INPUT(c);
    add_cu(a, b, c);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("add", &add, "add(CUDA)");
}
// add_cuda.cu
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include "add_cuda.h"

template <typename scalar_t>
__global__ void add_kernel(
    scalar_t* __restrict__ a, 
    scalar_t* __restrict__ b, 
    scalar_t* __restrict__ c, 
    size_t size
){
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride = blockIdx.x * gridDim.x;
    for (int i = index; i < size; i += stride){
        a[i] = b[i] + c[i];
    }
}


void add_cu(torch::Tensor a, torch::Tensor b, torch::Tensor c){
    const auto size = a.size(0);

    const int threads = 8;
    const dim3 blocks((size + threads - 1) / threads);

    AT_DISPATCH_FLOATING_TYPES(a.type(), "add cuda", ([&] {
        add_kernel<scalar_t><<<blocks, threads>>>(
            a.data<scalar_t>(), 
            b.data<scalar_t>(), 
            c.data<scalar_t>(), 
            size
        );
    }));
}

And the test py file:

# test.py
import torch
import add

if __name__ == "__main__":
    a = torch.zeros((100, ))
    b = torch.ones((100, )) * 10
    c = torch.ones((100, ))
    a = a.cuda(1)
    b = b.cuda(1)
    c = c.cuda(1)
    add.add(a, b, c)
    print(a)

In theory the printed output should be a torch tensor containing one hundred 11s, but the actual output is

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.], device='cuda:1')

It seems that the add_kernel in add_cuda.cu didn’t execute at all, but I don’t know why.
Any help would be appreciated. Thank you!

Could you check the definition of stride?
It seems it should be defined as the total number of threads in the grid:

const int stride = blockDim.x * gridDim.x;
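
In other words, the kernel body becomes the usual grid-stride loop, roughly like this (just a sketch with explanatory comments):

// index:  the first element this thread handles
// stride: the total number of threads launched, so each iteration of the
//         loop jumps one full grid ahead
const int index = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = blockDim.x * gridDim.x;
for (int i = index; i < size; i += stride){
    a[i] = b[i] + c[i];
}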

Thank you for your reply; I did indeed misunderstand that.
But it seems the essential problem lies elsewhere: even the first elements are still 0 after add_kernel runs, and after I changed stride the result was still all 0s.

That’s strange. I’ve just tested your code as an inline extension (via load_inline) and it seems to work:

import torch
from torch.utils.cpp_extension import load_inline

cpp_src = """
void add_cu(torch::Tensor a, torch::Tensor b, torch::Tensor c);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

void add(torch::Tensor a, torch::Tensor b, torch::Tensor c){
    CHECK_INPUT(a);
    CHECK_INPUT(b);
    CHECK_INPUT(c);
    add_cu(a, b, c);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("add", &add, "add(CUDA)");
}
"""

cuda_src = """
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

void add_cu(torch::Tensor a, torch::Tensor b, torch::Tensor c);


template <typename scalar_t>
__global__ void add_kernel(
    scalar_t* __restrict__ a, 
    scalar_t* __restrict__ b, 
    scalar_t* __restrict__ c, 
    size_t size
){
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    const int stride = blockDim.x * gridDim.x;
    for (int i = index; i < size; i += stride){
        a[i] = b[i] + c[i];
    }
}


void add_cu(torch::Tensor a, torch::Tensor b, torch::Tensor c){
    const auto size = a.size(0);

    const int threads = 8;
    const dim3 blocks((size + threads - 1) / threads);

    AT_DISPATCH_FLOATING_TYPES(a.type(), "add cuda", ([&] {
        add_kernel<scalar_t><<<blocks, threads>>>(
            a.data<scalar_t>(), 
            b.data<scalar_t>(), 
            c.data<scalar_t>(), 
            size
        );
    }));
}
"""

add = load_inline(name='add', cpp_sources=[cpp_src],
                   cuda_sources=[cuda_src])

a = torch.zeros((100, ))
b = torch.ones((100, )) * 10
c = torch.ones((100, ))
a = a.cuda(0)
b = b.cuda(0)
c = c.cuda(0)
add.add(a, b, c)
print(a)

Could you check if it gives the right result and compare your code against mine?

I’ve just found the cause: the program only works on cuda:0. I’m sorry, I actually used cuda(1) because GPU 0 was occupied by others at the time, but I wrote 0 in the code above.
Now that I have changed it to cuda(0), everything works.
But it is still very strange to me why the CUDA kernel doesn’t execute on any device other than cuda:0. Since our GPUs are heavily used, it is important that these PyTorch CUDA extensions can run on any GPU, so how can the kernel be made to run on the other GPUs?

I’m sorry, I forgot to mention the versions:
OS: Ubuntu 16.04 LTS
Python: 3.6.8
PyTorch: 1.0.0

You probably want to make the tensors’ device the current CUDA device before launching the kernel, e.g. with a device guard in add_cu (or a plain cudaSetDevice call). Kernels are launched on the current device, which defaults to cuda:0, so with your tensors on cuda:1 the kernel runs on the wrong device and your data is left untouched.
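
A rough sketch of how the launcher could look with a device guard (the exact header/class has moved around a bit between PyTorch versions, so treat this as a sketch; calling cudaSetDevice(a.get_device()) before the launch is the low-tech alternative):

#include <ATen/cuda/CUDAGuard.h>

void add_cu(torch::Tensor a, torch::Tensor b, torch::Tensor c){
    // Make a's device the current CUDA device for the duration of this scope,
    // so the kernel below is launched where the tensors actually live.
    const at::cuda::CUDAGuard device_guard(a.device());

    const auto size = a.size(0);
    const int threads = 8;
    const dim3 blocks((size + threads - 1) / threads);

    AT_DISPATCH_FLOATING_TYPES(a.type(), "add cuda", ([&] {
        add_kernel<scalar_t><<<blocks, threads>>>(
            a.data<scalar_t>(),
            b.data<scalar_t>(),
            c.data<scalar_t>(),
            size
        );
    }));
}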

Best regards

Thomas

P.S.: If you take inspiration from the PyTorch internals: the reason you don’t see this in aten/src/ATen/native/cuda/* is that there is an auto-generated wrapper doing it for you.
P.P.S.: The new torchvision C++ extension does it right and is a great example of a model PyTorch extension.


Thank you very much! It works!