C++ extension is slower than the corresponding Python version

I have implemented a batch linear regression algorithm in C++, which reads:

#include <torch/extension.h>
/// param x: (batch, sample, feature)
/// param y: (batch, sample)
/// return (batch, feature)
at::Tensor lr_batch_fit(const at::Tensor &x, const at::Tensor &y){
    const int64_t batch = x.size(0), sample = x.size(1), features = x.size(2);

    const at::Tensor & xt = at::transpose(x, 1, 2);        // X^T, shape (batch, feature, sample)
    const at::Tensor & s = at::bmm(xt, x);                  // X^T X, shape (batch, feature, feature)
    const at::Tensor & yp = y.view({batch, sample, 1});
    const at::Tensor & pre_beta = at::bmm(xt, yp);          // X^T y, shape (batch, feature, 1)

    at::Tensor s_lu, pivots, infos;
    std::tie(s_lu, pivots, infos) = at::_lu_with_info(s, true, true);
    /*
    const auto data = at::_lu_with_info(s, true, true);
    const at::Tensor & s_lu = std::get<0>(data);
    const at::Tensor & pivots = std::get<1>(data);
    const at::Tensor & infos = std::get<2>(data);
    */

    at::Tensor beta = at::zeros({batch, features, 1}, x.options());
    if(0 == infos.nonzero().size(0)){   // solve only if every batch factorization succeeded
        at::lu_solve_out(beta, pre_beta, s_lu, pivots);
    }
    return beta.squeeze(2);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("lr_batch_fit_c",   &lr_batch_fit);
}
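
For reference, this solves each batch’s ordinary-least-squares problem through the normal equations, using an LU factorization of X^T X:

$$
(X^\top X)\,\beta = X^\top y
\qquad\Longrightarrow\qquad
\hat\beta = (X^\top X)^{-1} X^\top y
$$

In the code above, s holds X^T X, pre_beta holds X^T y, and lu_solve recovers beta from the LU factors of s.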

The corresponding Python code is

import torch

def lr_batch_fit_py(x, y):
    """
    :param x: (batch, sample, features)
    :param y: (batch, sample,)
    :return: beta (batch, features)
    """
    batch, sample, features = x.shape
    xt = torch.transpose(x, 1, 2)
    s = torch.bmm(xt, x)
    y = y.view(batch, sample, 1)
    pre_beta = torch.bmm(xt, y)
    s_lu, pivots, infos = torch.lu(s, pivot=True, get_infos=True)

    beta = torch.zeros(batch, features, 1, dtype=x.dtype, device=x.device)
    if 0 == infos.nonzero().size(0):
        torch.lu_solve(pre_beta, s_lu, pivots, out=beta)
    return beta.squeeze(2)

and here is the testing code

import torch
import time
from lr_batch_c import lr_batch_fit_c
from lr_batch import lr_batch_fit_py

def test_lr_batch():
    def check_equal(x, y, eps=1e-8):
        assert torch.sum(torch.abs(x - y)) <= eps

    tt1, tt2 = 0, 0
    x = torch.randn(4, 48000, 160, dtype=torch.float64).cuda()
    beta = torch.randn(4, 160, 1, dtype=torch.float64).cuda()
    y = torch.bmm(x, beta).squeeze(2)
    beta = beta.squeeze(2)
    for i in range(100):
        t1 = time.time()
        beta2 = lr_batch_fit_c(x, y)
        t2 = time.time()
        beta3 = lr_batch_fit_py(x, y)
        t3 = time.time()
        tt1 += (t2 - t1)
        tt2 += (t3 - t2)
        check_equal(beta, beta2)
        check_equal(beta, beta3)
    print("test_lr_batch  c_time %.2fs py_time %.2fs" % (tt1, tt2))

After running test_lr_batch several times with CUDA, I found that the C++ version is slower than the Python version. My naive expectation was that, since the C++ version saves some time at the boundary between Python and C++, it should be faster. Could anyone help me figure this out?
(Note that on CPU, the C++ version is a little bit faster than the Python version, as anticipated.)
Thanks.

Hi,

If you’re running on CUDA, you need to add the appropriate torch.cuda.synchronize() calls to make your timings accurate (keep in mind that the whole CUDA API is asynchronous).
Also, how big a difference do you see?
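
For example, a minimal sketch of that timing pattern (the time_cuda helper and the iteration count here are just for illustration):

import time
import torch

def time_cuda(fn, iters=100):
    fn()                          # warm-up, so one-off CUDA setup is not counted
    torch.cuda.synchronize()      # wait for all previously queued kernels to finish
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()      # wait again before reading the clock
    return (time.time() - start) / iters

# e.g. time_cuda(lambda: lr_batch_fit_c(x, y))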

Hi,
my testing result is

test_lr_batch  c_time 1.54s py_time 1.20s

If i change x to torch.randn(4, 4800, 160), the result is

test_lr_batch  c_time 0.57s py_time 0.24s

and the difference seems (somewhat) significant.
(By the way, I am using PyTorch 1.4 + (NVIDIA-SMI 430.40, Driver Version 430.40, CUDA Version 10.1) + gcc 7.5.0 + Ubuntu 18.04.4 LTS.)

I added torch.cuda.synchronize() at different locations in the test_lr_batch function, and it made no difference to the results.

Hello,

Maybe a stupid question: are you sure your C++ app runs on CUDA? In release mode?

I am surprised, because my C++ app runs much faster than Python (but that’s for many different reasons).

Pascal

Thanks for your kind reminder; here are my double checks.
First, the setup file is

from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension

setup(name='lr_batch_c',
      ext_modules=[CppExtension('lr_batch_c', ['lr_batch.cpp'])],
      cmdclass={'build_ext': BuildExtension})

and I see the flags -DNDEBUG -O3 when compiling the cpp file with python setup.py build_ext --inplace.

Second, apart from the explicit transfer of x and y to CUDA, all other operations use standard PyTorch APIs, and those operations should run on CUDA, right?
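
As a quick sanity check (reusing the x and y from the test), the device of the output should tell you where the extension ran, since ATen dispatches on the inputs’ device:

beta2 = lr_batch_fit_c(x, y)
print(x.is_cuda, beta2.is_cuda, beta2.device)   # expect: True True cuda:0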

Testing step by step, it looks like the C++ version of the LU step

at::Tensor s_lu, pivots, infos;
std::tie(s_lu, pivots, infos) = at::_lu_with_info(s, true, true);

is slower than the Python version

s_lu, pivots, infos = torch.lu(s, pivot=True, get_infos=True)
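
To isolate that, here is a rough sketch that times just the LU step on the Python side, assuming s = torch.bmm(xt, x) has already been built on the GPU:

import time
import torch

# s: (batch, features, features), already a CUDA tensor
torch.cuda.synchronize()
t0 = time.time()
for _ in range(100):
    s_lu, pivots, infos = torch.lu(s, pivot=True, get_infos=True)
torch.cuda.synchronize()
print("torch.lu: %.3fs for 100 calls" % (time.time() - t0))

A like-for-like comparison of the C++ side would need a separate binding around just the at::_lu_with_info call.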

You need to put a torch.cuda.synchronize() before every call to time.time() to get actual timings. Otherwise you’ll just see random artifacts from the time it takes to queue work on the GPU and from when that queue is full.

Here is my new test_lr_batch function

def test_lr_batch():
    def check_equal(x, y, eps=1e-8):
        assert torch.sum(torch.abs(x - y)) <= eps

    tt1, tt2 = 0, 0
    x = torch.randn(4, 4800, 160, dtype=torch.float64).cuda()
    beta = torch.randn(4, 160, 1, dtype=torch.float64).cuda()
    y = torch.bmm(x, beta).squeeze(2)
    beta = beta.squeeze(2)
    torch.cuda.synchronize()
    for i in range(200):
        torch.cuda.synchronize()
        t1 = time.time()

        beta2 = lr_batch_fit_c(x, y)

        torch.cuda.synchronize()
        t2 = time.time()

        beta3 = lr_batch_fit_py(x, y)

        torch.cuda.synchronize()
        t3 = time.time()
        tt1 += (t2 - t1)
        tt2 += (t3 - t2)
        check_equal(beta, beta2)
        check_equal(beta, beta3)
    print("test_lr_batch  c_time %.3fs py_time %.3fs" % (tt1, tt2))

Based on the new function, the C++ version is still slower than the Python version.

Is the new function fine for testing, or should I also put similar calls (torch.cuda.synchronize()) in the C++ code?

Can you test your C++ code in a plain C++ environment without the binding (i.e., using libtorch)?
The Python APIs and the C++ APIs actually hit the same underlying C++ code, so the slowness is probably due to the Python binding for the cpp extension.

I don’t think the slowness is due to the Python binding, since I tested the following C++ code

#include <torch/extension.h>
/// param x: (batch, sample, feature)
/// param y: (batch, sample)
/// return (batch, feature)
at::Tensor lr_batch_fit(const at::Tensor &x, const at::Tensor &y){
    const int64_t batch = x.size(0), sample = x.size(1), features = x.size(2);

    const at::Tensor & xt = at::transpose(x, 1, 2);
    const at::Tensor & s = at::bmm(xt, x);
    const at::Tensor & yp = y.view({batch, sample, 1});
    const at::Tensor & pre_beta = at::bmm(xt, yp);

    at::Tensor beta = at::zeros({batch, features, 1}, x.options());
    return beta.squeeze(2);
}

and the corresponding Python version

import torch

def lr_batch_fit_py(x, y):
    """
    :param x: (batch, sample, features)
    :param y: (batch, sample,)
    :return: beta (batch, features)
    """
    batch, sample, features = x.shape
    xt = torch.transpose(x, 1, 2)
    s = torch.bmm(xt, x)
    y = y.view(batch, sample, 1)
    pre_beta = torch.bmm(xt, y)
    beta = torch.zeros(batch, features, 1, dtype=x.dtype, device=x.device)
    return beta.squeeze(2)

both versions show only a mild difference, and the C++ one wins the race.

Hi, I think there is a difference between torch.lu(s, pivot=True, get_infos=True) and at::_lu_with_info(s, true, true).

check here:

If get_infos is True, then the check_errors flag in the at:: version should be false, so you should use at::_lu_with_info(s, true, false) in the test.
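
For what it’s worth, a sketch of the Python-side equivalent (assuming the private torch._lu_with_info binding is exposed in your PyTorch version, which may vary):

# Sketch, assuming torch._lu_with_info is available:
# torch.lu(s, pivot=True, get_infos=True) forwards to the same op with
# check_errors set to the opposite of get_infos.
s_lu, pivots, infos = torch._lu_with_info(s, True, False)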

Let me know if the performance goes back to normal after this change.

Thanks for pointing out the difference, but sadly the performance remains the same after this change.

Note that since the heavy lifting is done by the exact same function, the only thing you measure is the overhead of calling that function.
Python overhead is significant if you loop over very simple ops that take microseconds, but if your function actually does substantial work, you won’t even see it.
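
To put rough numbers on that, here is a sketch (sizes copied from the test above) comparing the per-call cost of a trivial tensor call, which is mostly Python/binding overhead, with a single bmm of the kind the fit actually runs:

import time
import torch

x = torch.randn(4, 4800, 160, dtype=torch.float64).cuda()
xt = x.transpose(1, 2)

def bench(fn, iters=1000):
    fn()                        # warm-up
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

overhead = bench(lambda: x.dim())           # trivial call: essentially just call overhead
heavy = bench(lambda: torch.bmm(xt, x))     # one of the real ops inside the fit
print("trivial: %.1f us/call, bmm: %.1f us/call" % (overhead * 1e6, heavy * 1e6))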