C++ Custom Operator is slow in this case

#include <torch/extension.h>

#include <iostream>
#include <vector>

std::vector<at::Tensor> lltm_forward(torch::Tensor dummy, 
torch::Tensor mult1, int element_per_filter, int mult1_shape_0, 
int y_index_2, int mult1_shape_1)
{
    // auto dummy1 = dummy.accessor<float, 2>();
    // auto mult11 = mult1.accessor<float, 2>();
    int iterator = 0;
    float number = 0;
    //mult1 = mult1.data<float>;
    int temp = 0;
    int filternumber = -1;
    for (int i = 0; i < mult1_shape_0 + 1; i = i + y_index_2)
    {
        for (int j = 0; j < mult1_shape_1; j++)
        {
            number = 0;

            for (int k = iterator; k < i; k++)
            {
                number = number + mult1[k][j].item().to<double>();
                temp = k;
            }
            dummy[temp][j] = number;
        }
    }

    return {dummy};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("nice", &lltm_forward, "LLTM forward");
}

This code works, but it is around 400x slower than its Python variant. When I replace mult1[k][j].item().to<double>() with a plain float value (say, i from the loop), it becomes 66% faster, which is what I want from this operator. I was thinking that if a float tensor were passed from the Python end instead of torch::Tensor mult1, looping through it would be faster (just doing mult1[k][j]). I tried to convert mult1 to float as follows, but couldn't make that work either:

    auto x = mult1.to(torch::kFloat32);

Which still gives:

: error: cannot convert ‘at::Tensor’ to ‘float’ in assignment
                 number = number + x[k][j]; //.item().to<double>();

Any guidance on this issue would be appreciated. Is it possible to send NumPy arrays to a C++ operator, loop over them there, and return a NumPy array? I have tried that too but have not been able to make it work yet.

Hi,

Is the input Tensor on GPU by any chance?
Also, what is the Python counterpart that is faster?


Thank you for your reply. The tensors are actually converted to NumPy arrays for the Python loop, which runs on those arrays. I thought C++ loops would be much faster, so a custom operator seemed like a good choice. Here is a snippet demonstrating both parts:

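    import time

    import torch
    import nice  # the compiled C++ extension (assumed to be built under the name "nice")

    # dummy and mult1 are 2-D float NumPy arrays; element_per_filter and
    # yindex_2 are ints. All four are defined earlier in the script.

    #######################
    # Calling C++ made operator here
    a = torch.from_numpy(dummy).to(torch.float)
    b = torch.from_numpy(mult1).to(torch.float)
    start = time.time()  # torch tensors a and b are passed to the C++ operator
    nani = nice.nice(
        a, b, element_per_filter, mult1.shape[0], yindex_2, mult1.shape[1])
    #######################
    end = time.time()
    print(end - start, "C++")

    iterator = 0  # first row of each summation, as in the C++ version
    temp = 0      # must exist before i == 0, where range(iterator, 0) is empty
    start = time.time()

    for i in range(0, mult1.shape[0] + 1, yindex_2):   # python looping on numpy arrays
        for j in range(0, mult1.shape[1], 1):
            number = 0
            for k in range(iterator, i, 1):
                number = number + mult1[k][j]
                temp = k
            dummy[temp][j] = number
    end = time.time()
    print(end - start, "Python Loop")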

I had a similar experience with libtorch. Looking at your code, you are not using any libtorch-specific functionality that ties you to tensors, so you should be able to swap the tensor for a vector easily.
Just converting mult1 from a tensor into a vector would greatly boost your performance.
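A minimal sketch of that suggestion, assuming both matrices have already been copied into row-major std::vector<float> buffers (for example, from a contiguous CPU float tensor). The function name loop_on_vectors and its parameter order are hypothetical. It keeps the original loop's behaviour, including iterator staying at 0, but replaces each mult1[k][j].item() call with a plain float read:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: the same triple loop as lltm_forward, but reading and
// writing row-major std::vector<float> buffers instead of indexing tensors.
// Element (r, c) of a matrix with `cols` columns lives at index r * cols + c.
// Assumes y_index_2 > 0 (otherwise the outer loop never terminates).
void loop_on_vectors(std::vector<float>& dummy, int dummy_cols,
                     const std::vector<float>& mult1,
                     int mult1_shape_0, int mult1_shape_1, int y_index_2)
{
    const int iterator = 0;  // as in the original, iterator is never advanced
    int temp = 0;
    for (int i = 0; i < mult1_shape_0 + 1; i += y_index_2)
    {
        for (int j = 0; j < mult1_shape_1; j++)
        {
            float number = 0.0f;
            // Sum column j over rows [iterator, i) with plain float reads,
            // instead of one mult1[k][j].item() call per element.
            for (int k = iterator; k < i; k++)
            {
                number += mult1[static_cast<std::size_t>(k) * mult1_shape_1 + j];
                temp = k;
            }
            dummy[static_cast<std::size_t>(temp) * dummy_cols + j] = number;
        }
    }
}
```

Reading directly through data_ptr<float>() on a contiguous CPU tensor gives the same access pattern without copying into a vector first.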


If possible, could you please provide a snippet for this code? I am not using any tensor-related functionality, which is why I want to move these loops to C++.

I tried rewriting it with raw arrays (without really knowing what tensor sizes you are using), and it seems much faster:

std::vector<at::Tensor> lltm_forward2(torch::Tensor dummy, torch::Tensor mult1, int element_per_filter, int mult1_shape_0, int y_index_2, int mult1_shape_1)
{
	TORCH_CHECK(dummy.is_contiguous(), " dummy must be a contiguous tensor");
	TORCH_CHECK(mult1.is_contiguous(), " mult1 must be a contiguous tensor");
	TORCH_CHECK(!dummy.is_cuda(), " dummy can't be CUDA tensor");
	TORCH_CHECK(!mult1.is_cuda(), " mult1 can't be CUDA tensor");

	auto dummyPtr = dummy.data_ptr<float>();
	auto mult1Ptr = mult1.data_ptr<float>();

	const size_t dummyPtr_stride = dummy.size(1);
	const size_t mult1Ptr_stride = mult1.size(1);

	int iterator = 0;
	float number = 0.0f;

	int temp = 0;
	int filternumber = -1;

	for (int i = 0; i < mult1_shape_0 + 1; i = i + y_index_2)
	{
		for (int j = 0; j < mult1_shape_1; j++)
		{
			number = 0.0f;

			for (int k = iterator; k < i; k++)
			{
				number += mult1Ptr[k * mult1Ptr_stride + j];
				temp = k;
			}

			dummyPtr[temp * dummyPtr_stride + j] = number;
		}
	}

	return { dummy };
}

Tested it with this:

#include <chrono>    // std::chrono::steady_clock, duration_cast
#include <cstdlib>   // system()

// Note: system("PAUSE") is Windows-only; remove those calls on other platforms.
int main(int argc, char** argv)
{
	try
	{
		int SZ1 = 30;
		int SZ2 = 40;

		torch::Tensor dummySrc = torch::randn({ SZ1, SZ2 });
		torch::Tensor mult1Src = torch::randn({ SZ1, SZ2 });

		torch::Tensor result1, result2;

		{
			// Original version: indexes the tensors element by element.
			auto dummy = dummySrc.detach().clone();
			auto mult1 = mult1Src.detach().clone();
			std::chrono::steady_clock::time_point begin1 = std::chrono::steady_clock::now();
			auto output1 = lltm_forward(dummy, mult1, 0, mult1.size(0), 1, mult1.size(1));
			std::chrono::steady_clock::time_point end1 = std::chrono::steady_clock::now();
			std::cout << "Time difference = " << std::chrono::duration_cast<std::chrono::microseconds>(end1 - begin1).count() << "[µs]" << std::endl;
			result1 = output1[0];

			//std::cout << output1[0] << std::endl;
		}

		{
			// Raw-pointer version.
			auto dummy = dummySrc.detach().clone();
			auto mult1 = mult1Src.detach().clone();
			std::chrono::steady_clock::time_point begin2 = std::chrono::steady_clock::now();
			auto output2 = lltm_forward2(dummy, mult1, 0, mult1.size(0), 1, mult1.size(1));
			std::chrono::steady_clock::time_point end2 = std::chrono::steady_clock::now();
			std::cout << "Time difference = " << std::chrono::duration_cast<std::chrono::microseconds>(end2 - begin2).count() << "[µs]" << std::endl;

			result2 = output2[0];

			//std::cout << output2[0] << std::endl;
		}

		// Both versions should produce the same result.
		std::cout << "Max error: " << (result2 - result1).abs().max() << std::endl;

		system("PAUSE");
		return 0;
	}
	catch (std::runtime_error& e)
	{
		std::cout << e.what() << std::endl;
	}
	catch (const c10::Error& e)
	{
		std::cout << e.msg() << std::endl;
	}

	system("PAUSE");
	return 1;
}

EDIT: fixed stride values
EDIT2: changed from C++ assert to TORCH_CHECK


Thank you very much. This is exactly what I needed. Just one more question: how should I adjust the stride for

    number += mult1Ptr[k * mult1Ptr_stride + j];    // mult1[k][j]
    dummyPtr[temp * dummyPtr_stride + j] = number;  // dummy[temp][j] = number

I don't get the same answer as with the original loop.
(I am not very comfortable with pointers.)

Sorry, it should have been

const size_t dummyPtr_stride = dummy.size(1);
const size_t mult1Ptr_stride = mult1.size(1);

I was using the same size for both dimensions, so it worked.
I have also changed it in the previous post.
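To see why the stride has to be size(1), the number of columns, here is a small torch-free sketch with a hypothetical helper flat_index. In a contiguous row-major matrix each row is stored right after the previous one, so stepping down one row skips exactly one row's worth of elements. With a square matrix size(0) equals size(1), which is why the wrong stride went unnoticed:

```cpp
#include <cstddef>
#include <vector>

// Offset of element (row, col) in a contiguous row-major buffer whose rows
// each hold `stride` elements. For a contiguous 2-D tensor that stride is
// size(1), the number of columns, not size(0).
std::size_t flat_index(std::size_t row, std::size_t col, std::size_t stride)
{
    return row * stride + col;
}

// A 2 x 3 matrix laid out row by row:
//   [ 10 11 12 ]
//   [ 20 21 22 ]
const std::vector<float> kMatrix = {10, 11, 12, 20, 21, 22};
const std::size_t kRows = 2;  // size(0)
const std::size_t kCols = 3;  // size(1)
```

With kCols as the stride, (1, 0) maps to offset 3 and reads 20 as expected. With kRows as the stride, (1, 0) maps to offset 2 and reads 12, which is really element (0, 2).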


You are awesome. 🙂 Thank you! The timings are now:

0.0011734962463378906 C++
0.8707504272460938 Python Loop

Edit: entered the timing values (the C++ operator is roughly 740x faster than the Python loop).