Poor Performance with CPU Threads (C++ extension)

I’m trying to rewrite the following sparse multiplication operation as a C++ extension.

def sparse_mul(weights, bias, in_values, active_out_indices):
    # for the ith input in the batch, produce outputs only for the nodes in active_out_indices[i]
    # active_out_indices.size(1) is expected to be around 10% of weights.size(0)
    active_weights = weights[active_out_indices]
    active_bias = bias[active_out_indices]
    out_values = torch.bmm(active_weights, in_values.unsqueeze(-1)).squeeze(-1) + active_bias
    return out_values

I want to do this because the above Python implementation performs unnecessary data copies due to the advanced indexing. The following is the C++ code I have as of now. It runs almost 50x slower (on my 4-core CPU) than the Python implementation and also than the regular dense multiplication. Can someone tell me what I’m doing wrong? Can I not expect a speedup here? I’m new to PyTorch C++ and am trying to follow the solution given in https://discuss.pytorch.org/t/using-at-parallel-for-in-a-custom-operator/82747/5

// sparse_mul.cpp
#include <torch/extension.h>
#include <vector>

#define _OPENMP
#include <ATen/ParallelOpenMP.h>

torch::Tensor mul(
        torch::Tensor in_values,
        torch::Tensor active_out_indices,
        torch::Tensor weights,
        torch::Tensor bias) {
    
    auto out_values = torch::empty(active_out_indices.sizes());
    auto active_out_indices_a = active_out_indices.accessor<long,2>();
    int batch_size = active_out_indices_a.size(0);
    int active_out_dim = active_out_indices_a.size(1);

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
        for (int64_t i = start; i < end; i++){
            torch::set_num_threads(1);
            auto in_values_i = in_values[i];
            for(int j=0; j<active_out_dim; j++){
                out_values[i][j] = torch::dot(in_values_i, weights[active_out_indices_a[i][j]]) + bias[active_out_indices_a[i][j]];
            } 
        }
    });
    return out_values;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("mul", &mul, "sparse multiplicatio");
}

# setup.py
from setuptools import setup, Extension
from torch.utils import cpp_extension

cpp_module = cpp_extension.CppExtension('sparse_mul_cpp',
                                        sources=['sparse_mul.cpp'],
                                        extra_compile_args=['-fopenmp'],
                                        extra_link_args=['-lgomp']
                                        )

setup(name='sparse_mul_cpp',
      ext_modules=[cpp_module],
      cmdclass={'build_ext': cpp_extension.BuildExtension})

I also get the error “a leaf Variable that requires grad is being used in an in-place operation” when I try to use larger sizes for the weight matrix. I’m using PyTorch version 1.6.0.

Thank you for your time 🙂

That code is too high-level, so the inner loop is slow.

  1. Dispatch by dtype, like the built-in kernels do.
  2. Instead of Tensor::operator[] indexing, use either raw pointers (this implies contiguous tensors and doing the address arithmetic yourself) or “accessors” (torch API). A rough pointer-based sketch is given after this list.
  3. Compute the dot product manually.
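
To make these points concrete, here is a minimal sketch of the raw-pointer variant (hypothetical code, not from this thread; the name mul_sketch is made up, and it assumes all four tensors are contiguous with int64 indices):

// Hypothetical sketch of points 1-3, assuming contiguous inputs: dispatch on
// dtype, walk the tensors through raw pointers, and accumulate the dot product
// by hand.
#include <torch/extension.h>

torch::Tensor mul_sketch(
        const torch::Tensor& in_values,           // [B, K]
        const torch::Tensor& active_out_indices,  // [B, J], int64
        const torch::Tensor& weights,             // [N, K]
        const torch::Tensor& bias) {              // [N]

    auto out_values = torch::empty(active_out_indices.sizes(), in_values.options());
    const int64_t B = out_values.size(0);
    const int64_t J = out_values.size(1);
    const int64_t K = weights.size(1);
    const int64_t* idx = active_out_indices.data_ptr<int64_t>();

    AT_DISPATCH_FLOATING_TYPES(in_values.scalar_type(), "mul_sketch", [&] {
        const scalar_t* in_p = in_values.data_ptr<scalar_t>();
        const scalar_t* w_p  = weights.data_ptr<scalar_t>();
        const scalar_t* b_p  = bias.data_ptr<scalar_t>();
        scalar_t* out_p      = out_values.data_ptr<scalar_t>();

        at::parallel_for(0, B, 0, [&](int64_t start, int64_t end) {
            for (int64_t i = start; i < end; i++) {
                const scalar_t* in_i = in_p + i * K;      // row i of the batch
                for (int64_t j = 0; j < J; j++) {
                    const int64_t n = idx[i * J + j];     // active output node
                    const scalar_t* w_n = w_p + n * K;    // its weight row
                    scalar_t acc = b_p[n];
                    for (int64_t k = 0; k < K; k++)
                        acc += in_i[k] * w_n[k];
                    out_p[i * J + j] = acc;
                }
            }
        });
    });
    return out_values;
}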

Thank you for your advice @googlebot. I was able to achieve a significant speedup after incorporating your suggestions. Here is my current code in case it is of use to anyone. Is there anything else I could do to improve performance further? I’m now going to look into implementing backprop for it.

// sparse_mul.cpp
#include <torch/extension.h>
#include <vector>

#define _OPENMP
#include <ATen/ParallelOpenMP.h>

template <typename scalar_t>
void mul2_kernel(
        torch::Tensor& out_values,
        const torch::Tensor& in_values,
        const torch::Tensor& active_out_indices,
        const torch::Tensor& weights,
        const torch::Tensor& bias){

    int64_t batch_size = out_values.size(0);
    int64_t active_out_dim = out_values.size(1);
    int64_t k_size = weights.size(1);

    auto out_values_0 = out_values.accessor<scalar_t, 2>();
    auto in_values_0 = in_values.accessor<scalar_t, 2>();
    auto active_out_indices_0 = active_out_indices.accessor<int64_t,2>();
    auto weights_0 = weights.accessor<scalar_t, 2>();
    auto bias_0 = bias.accessor<scalar_t, 1>();

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end){
        for (int64_t i = start; i < end; i++){
            torch::set_num_threads(1); // is this required?
            auto out_values_1 = out_values_0[i];
            auto in_values_1 = in_values_0[i];
            auto active_out_indices_1 = active_out_indices_0[i];
            
            for(int64_t j=0; j<active_out_dim; j++){
                scalar_t &res = out_values_1[j];
                int64_t out_index = active_out_indices_1[j];
                auto weights_1 = weights_0[out_index];

                res = bias_0[out_index];
                for(int64_t k=0; k<k_size; k++)
                    res += in_values_1[k]*weights_1[k];
            }
        }
        // std::cout << "in thread " << omp_get_thread_num() << std::endl;
    });
}

torch::Tensor mul2(
        const torch::Tensor& in_values,
        const torch::Tensor& active_out_indices,
        const torch::Tensor& weights,
        const torch::Tensor& bias){
    
    // std::cout<< "omp_get_max_threads() " << omp_get_max_threads() << std::endl;

    auto out_values = torch::empty(active_out_indices.sizes(), in_values.options());
    AT_DISPATCH_ALL_TYPES(
        in_values.scalar_type(), "mul2", [&] {
            mul2_kernel<scalar_t>(out_values, in_values, active_out_indices, weights, bias);
        });
    return out_values;
}

Looks good. Performance-wise, this won’t compile to AVX instructions, but these are tricky. You may check how dot is calculated in BlasKernel.cpp to get an idea. Otherwise, you can use AT_DISPATCH_FLOATING_TYPES and remove the set_num_threads line.
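
For reference, one way to give the compiler a chance to vectorize that inner loop is to run the accumulation over raw contiguous pointers with an OpenMP simd reduction. This is only a sketch under those assumptions (contiguous tensors, the extension already built with -fopenmp), not how the built-in dot kernel does it:

// Hypothetical vectorization sketch for the innermost loop of mul2_kernel,
// assuming contiguous tensors and a compiler invoked with -fopenmp.
template <typename scalar_t>
inline scalar_t dot_simd(const scalar_t* a, const scalar_t* b, int64_t n) {
    scalar_t acc = 0;
    // The simd reduction lets the compiler emit packed (e.g. AVX) instructions.
    #pragma omp simd reduction(+:acc)
    for (int64_t k = 0; k < n; k++)
        acc += a[k] * b[k];
    return acc;
}

// Inside the j-loop, with in_p and w_p obtained once via data_ptr<scalar_t>():
//     out_values_1[j] = bias_0[out_index]
//                     + dot_simd(in_p + i * k_size, w_p + out_index * k_size, k_size);

Whether this actually beats the plain loop depends on the compiler and on k_size, so it is worth benchmarking both versions.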