C++ extension performance issue

I’m writing a C++ extension to optimize the performance of a custom function. As a starting point, I’m focusing only on the CPU version.
My optimized version should in principle be ~30% faster.
However, it runs in about the same time. Here is my code:

Unoptimized version:

#include <torch/extension.h>

torch::Tensor symsum_forward_unoptimized(torch::Tensor input){
    const int shift = input.size(1) / 2;
    float data[] = {0};
    const torch::Tensor& zero = torch::from_blob(data, {});
    auto relu = torch::maximum(input, zero);
    auto inv_relu = torch::minimum(input, zero);
    auto output = torch::roll(inv_relu, shift, 1) + relu;
    return output;
}

Optimized version:

#include <cmath>  // std::signbit

torch::Tensor symsum_forward_ptr(torch::Tensor input){
    // Create and initialize the output tensor.
    const torch::Tensor& output = torch::zeros_like(input);

    // Batch stride is needed for pointer arithmetic.
    const int batch_stride = input.strides()[0]; 
    const int half_batch_stride = batch_stride / 2;

    // Define and initialize `input` and `output` tensor pointers.
    float *in_i1 = input.data_ptr<float>();  // input index 1
    float *in_i2 = in_i1 + half_batch_stride;  // input index 2
    float *out_i1 = output.data_ptr<float>();  //output index 1
    float *out_i2 = out_i1 + half_batch_stride;    //output index 2

    // Define and initialize end-of-batch and end-of-tensor pointers for `input`.
    float *batch_end = in_i2 + half_batch_stride;
    const float* end = input.data_ptr<float>() + input.numel();
     
    // This loop goes over samples in the batch
    while(in_i2 < end){ 
        // This loop goes over elements in a sample
        while (in_i2 < batch_end){
            // `std::signbit` is used instead of `(*in_i1) < 0`. It is faster.
            if (std::signbit(*in_i1)) {
                if (std::signbit(*in_i2)){
                    *out_i1 = *in_i2;
                    *out_i2 = *in_i1;                    
                }
                else{
                    *out_i2 = (*in_i2) + (*in_i1);
                }
            }
            else{
                if (std::signbit(*in_i2)){
                    *out_i1 = (*in_i2) + (*in_i1);
                }
                else{
                    *out_i1 = *in_i1;
                    *out_i2 = *in_i2;                    
                }                
            }
            in_i1++;
            in_i2++;
            out_i1++;
            out_i2++;  
        }

        in_i1 += half_batch_stride;
        in_i2 += half_batch_stride;        
        out_i1 += half_batch_stride;
        out_i2 += half_batch_stride; 
        batch_end += batch_stride;
    }    
    return output;
}


The idea behind the optimized version is simple. In the unoptimized version, each element of the input is compared with zero twice (once in maximum and once in minimum). However, this only needs to happen once: if an element is greater than or equal to zero, then I know it is not smaller than zero.

The optimized version takes advantage of this insight. It computes the same function, but with only one comparison per element. Furthermore, it allocates one third of the memory of the unoptimized one, so it should be faster in principle.
Any insight about why it is not faster is much appreciated.
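
For reference, this is the per-element case analysis the nested if implements (a hypothetical scalar helper, just to spell out the four sign combinations; symsum_pair is not part of my extension):

// For a pair (v1, v2), where v2 sits half a row away from v1, the op computes
//   out1 = max(v1, 0) + min(v2, 0)   and   out2 = max(v2, 0) + min(v1, 0).
inline void symsum_pair(float v1, float v2, float& out1, float& out2){
    if (v1 < 0.0f){
        if (v2 < 0.0f){ out1 = v2;      out2 = v1;      }  // both negative
        else          { out1 = 0.0f;    out2 = v1 + v2; }  // v1 < 0 <= v2
    }
    else{
        if (v2 < 0.0f){ out1 = v1 + v2; out2 = 0.0f;    }  // v2 < 0 <= v1
        else          { out1 = v1;      out2 = v2;      }  // both non-negative
    }
}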

I’m not sure if your theoretical analysis is correct, as it seems you are comparing a nested loop approach while the native implementation would use the TensorIterator with the at::parallel_for approach.
You could check the source code via max_out, which should call into the max_stub, max_kernel_impl, compare_base_kernel_core, and eventually into TensorIteratorBase::for_each.
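
For illustration, a single-pass loop over rows could be dispatched through at::parallel_for roughly like this (a minimal sketch, not the native kernel; it assumes a contiguous 2-D float input, and the function name is made up):

#include <ATen/Parallel.h>
#include <algorithm>
#include <torch/extension.h>

torch::Tensor symsum_forward_parallel(torch::Tensor input){
    torch::Tensor output = torch::zeros_like(input);
    const int64_t rows = input.size(0);
    const int64_t row_stride = input.strides()[0];
    const int64_t half = row_stride / 2;
    const float* in = input.data_ptr<float>();
    float* out = output.data_ptr<float>();

    // Each worker gets a contiguous block of rows; the grain size is a guess.
    at::parallel_for(0, rows, /*grain_size=*/1, [&](int64_t begin, int64_t end){
        for (int64_t r = begin; r < end; ++r){
            const float* in1 = in + r * row_stride;
            const float* in2 = in1 + half;
            float* out1 = out + r * row_stride;
            float* out2 = out1 + half;
            for (int64_t i = 0; i < half; ++i){
                const float v1 = in1[i], v2 = in2[i];
                out1[i] = std::max(v1, 0.0f) + std::min(v2, 0.0f);
                out2[i] = std::max(v2, 0.0f) + std::min(v1, 0.0f);
            }
        }
    });
    return output;
}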

Even in single-thread mode, your code won’t be auto-vectorized, while PyTorch utilizes AVX2.


Thank you for the comments.

@ptrblck Each iteration of the loop performs the comparison for two elements (hence the nested if). So in the end each element is only compared once, and I think the theoretical analysis is sound.

I’m testing the code with torch.set_num_threads(1), so I’m assuming both algorithms are on a level playing field in terms of parallelization(?)

@googlebot Do you anticipate that if I disabled AVX2 (or tried the code on a CPU without AVX instructions), then the optimized code would run faster?

Um, no, I mean that PyTorch processes 8 items at once; that’s a potential reason why you can’t reach its speed. It may be possible to rewrite your loop with AVX intrinsics, but it is not trivial; it may also auto-vectorize if you replace the branches with b ? x : y operations and add __restrict to the pointers. I’m just not sure if signbit is supported.

PS: Oh, I misunderstood - if you use “slow” PyTorch for comparison, then yes, I don’t rule out that your routine will be faster. It messes with the branch predictor, though.

I tried implementing an AVX version. On my PC it runs about 3x faster:
Duration symsum_avx: 0.127393 s
Duration symsum_forward_unoptimized: 0.370589 s
EDIT: changed the loads and stores to the unaligned versions, as suggested.

Here’s the code, if you’re interested:

#include <immintrin.h>  // AVX intrinsics
#include <cmath>        // std::signbit
#include <cstring>      // memset

torch::Tensor symsum_avx(torch::Tensor input)
{
	const torch::Tensor& output = torch::empty_like(input);

	const int batch_stride = input.strides()[0];
	const int half_batch_stride = batch_stride / 2;

	float* in_i1 = input.data_ptr<float>(); 
	float* in_i2 = in_i1 + half_batch_stride;
	float* out_i1 = output.data_ptr<float>();
	float* out_i2 = out_i1 + half_batch_stride;

	float* batch_end = in_i2 + half_batch_stride;
	const float* end = input.data_ptr<float>() + input.numel();

	const __m256 avx_zeros = _mm256_setzero_ps();

	// A float whose bits are all set, used as an "all true" mask for blendv.
	float all_bits_on = 0.0f;
	memset(&all_bits_on, ~0, sizeof(float));
	const __m256 avx_all_bits_on = _mm256_set1_ps(all_bits_on);

	while (in_i2 < end)
	{
		while ((batch_end - in_i2) >= 8)
		{
			auto in1 = _mm256_loadu_ps(in_i1);
			auto in2 = _mm256_loadu_ps(in_i2);

			const auto b1 = _mm256_cmp_ps(in1, avx_zeros, _CMP_LT_OQ);
			const auto b2 = _mm256_cmp_ps(in2, avx_zeros, _CMP_LT_OQ);

			const auto c1 = _mm256_and_ps(b1, b2);
			const auto c2 = _mm256_andnot_ps(b1, b2);
			const auto c3 = _mm256_andnot_ps(b2, b1); 
			const auto c4 = _mm256_andnot_ps(b1, _mm256_andnot_ps(b2, avx_all_bits_on));

			const auto sum = _mm256_add_ps(in1, in2);

			auto out1 = _mm256_blendv_ps(avx_zeros, in1, c4);
			out1 = _mm256_blendv_ps(out1, sum, c2);
			out1 = _mm256_blendv_ps(out1, in2, c1);

			auto out2 = _mm256_blendv_ps(avx_zeros, in2, c4);
			out2 = _mm256_blendv_ps(out2, sum, c3);
			out2 = _mm256_blendv_ps(out2, in1, c1);

			_mm256_storeu_ps(out_i1, out1);
			_mm256_storeu_ps(out_i2, out2);
			out_i1 += 8; out_i2 += 8; in_i1 += 8; in_i2 += 8;
		}

		while (in_i2 < batch_end)
		{
			const bool b1 = std::signbit(*in_i1);
			const bool b2 = std::signbit(*in_i2);

			const bool c1 = b1 && b2;
			const bool c2 = !b1 && b2;
			const bool c3 = b1 && !b2;
			const bool c4 = !b1 && !b2;

			const float sum = (*in_i2) + (*in_i1);

			// The output comes from empty_like, so every element must be written;
			// in the mixed-sign cases the "other" output is zero.
			*out_i1++ = c1 ? *in_i2 : (c2 ? sum : (c4 ? *in_i1 : 0.0f));
			*out_i2++ = c1 ? *in_i1 : (c3 ? sum : (c4 ? *in_i2 : 0.0f));

			++in_i1;
			++in_i2;
		}

		in_i1 += half_batch_stride;
		in_i2 += half_batch_stride;
		out_i1 += half_batch_stride;
		out_i2 += half_batch_stride;
		batch_end += batch_stride;
	}

	return output;
}

How I tested:


#include <chrono>
#include <iostream>

int main()
{
	auto input = torch::randn({ 10000, 10000 });

	auto avx_start = std::chrono::steady_clock::now();
	auto output_avx = symsum_avx(input);
	auto avx_end = std::chrono::steady_clock::now();

	auto default_start = std::chrono::steady_clock::now();
	auto output_default = symsum_forward_unoptimized(input);
	auto default_end = std::chrono::steady_clock::now();

	std::cout << "Duration symsum_avx: " << (std::chrono::duration_cast<std::chrono::microseconds>(avx_end - avx_start).count() / 1000000.0) << " s" << std::endl;
	std::cout << "Duration symsum_forward_unoptimized: " << (std::chrono::duration_cast<std::chrono::microseconds>(default_end - default_start).count() / 1000000.0) << " s" << std::endl;

	std::cout << "Max error: " << (output_avx - output_default).abs().max() << std::endl;
}

Nice!

You may need to use _mm256_loadu_ps / _mm256_storeu_ps, I believe.


I thought that PyTorch’s memory is properly aligned for AVX (by looking here). Am I wrong?

For freshly allocated tensors, yes; but if I slice away the first element, for example, not so much. I guess your code isn’t completely safe w.r.t. memory layout anyway.
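
For example (a quick sketch): a sliced view offsets the data pointer, so 32-byte alignment is no longer guaranteed:

auto t = torch::randn({8, 8});             // freshly allocated, typically 64-byte aligned
auto s = t.slice(/*dim=*/1, /*start=*/1);  // same storage, data_ptr() offset by 4 bytes
// reinterpret_cast<std::uintptr_t>(s.data_ptr<float>()) % 32 != 0 in general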

@Matej_Kompanek Wow, you are an angel! On my PC it runs ~4x faster than my unoptimized version. Now I need to decipher this to write the backward function. Can anyone suggest a good source for learning AVX?

Do you think a similar gain is achievable by optimizing the CUDA version?


No idea, I didn’t expect such an improvement with AVX.
I think there is no point in messing around with AVX; you should probably focus on CUDA, since you’re going to train on the GPU anyway (I presume).


On CUDA, you gain more from single-pass processing (fusion), so you may not need any special “optimizing” beyond that; if you replace your three (max, min, +) kernel calls with one, the speedup will be significant, as most of the overhead is memory I/O (especially if the data doesn’t fit in cache).
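
To make the fusion idea concrete, a single fused CUDA kernel might look roughly like this (an untested sketch; it assumes a contiguous 2-D float CUDA tensor, and the names are made up):

#include <torch/extension.h>

__global__ void symsum_fused_kernel(const float* __restrict__ in, float* __restrict__ out,
                                    int64_t rows, int64_t row_stride){
    const int64_t half = row_stride / 2;
    const int64_t n_pairs = rows * half;
    const int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
    if (idx >= n_pairs) return;
    const int64_t r = idx / half;
    const int64_t c = idx % half;
    const float v1 = in[r * row_stride + c];
    const float v2 = in[r * row_stride + half + c];
    // Single pass: each element is read once and each output written once.
    out[r * row_stride + c]        = fmaxf(v1, 0.f) + fminf(v2, 0.f);
    out[r * row_stride + half + c] = fmaxf(v2, 0.f) + fminf(v1, 0.f);
}

torch::Tensor symsum_forward_cuda(torch::Tensor input){
    torch::Tensor output = torch::empty_like(input);
    const int64_t rows = input.size(0);
    const int64_t row_stride = input.strides()[0];
    const int64_t n_pairs = rows * (row_stride / 2);
    const int threads = 256;
    const int blocks = static_cast<int>((n_pairs + threads - 1) / threads);
    symsum_fused_kernel<<<blocks, threads>>>(
        input.data_ptr<float>(), output.data_ptr<float>(), rows, row_stride);
    return output;
}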

As for AVX, the “Intel® 64 and IA-32 Architectures Software Developer’s Manual” is the canonical guide; search for the PDF, it is on Intel’s site somewhere. The intrinsics reference is the Intel® Intrinsics Guide.

But this is assembler-level programming, really, and as such it is not portable. It is better to write auto-vectorizable code by satisfying some compiler-specific requirements.


Is there a way to find out what those compiler-specific requirements for auto-vectorization are?

Yes, there is usually some documentation, for example Auto-vectorization in GCC - GNU Project.

I tried to vectorize your loop out of curiosity - Compiler Explorer. I haven’t tested it, but it should be the same as with intrinsics.

Note that I:

  1. added const __restrict modifiers
  2. added an indexing (induction) variable
  3. removed the if-else branches

All of this simplifies the code generator’s job; compilers give up on auto-vectorization very easily, unfortunately.

I’ll paste it here just in case:

void f(const float * __restrict in_i1, float * __restrict out_i1, int batch_stride, int numel)
{
    const int half_batch_stride = batch_stride / 2;
    const float *in_i2 = in_i1 + half_batch_stride;  // input index 2
    float *out_i2 = out_i1 + half_batch_stride;
    const float* end = in_i1 + numel;

    while(in_i2 < end){ 
        // One row: half_batch_stride (v1, v2) pairs, indexed instead of walking pointers.
        for(int i=0; i<half_batch_stride; i++) {
            float v1 = in_i1[i];
            float v2 = in_i2[i];
            float sum = v1+v2;
            bool b1 = v1 < 0, b2 = v2 < 0;

            // out1 = max(v1,0) + min(v2,0); out2 = max(v2,0) + min(v1,0)
            out_i1[i] = b1 ? (b2 ? v2 : 0.0f) : (b2 ? sum : v1);
            out_i2[i] = b1 ? (b2 ? v1 : sum) : (b2 ? 0.0f : v2);
        }
        // Advance all pointers by one full row.
        in_i1 += batch_stride;
        in_i2 += batch_stride;
        out_i1 += batch_stride;
        out_i2 += batch_stride;
   }
}

I updated the function, inspired by your code. It takes 1.5x longer than the unoptimized version to run on my PC, so I don’t think it is being vectorized.

torch::Tensor f_forward_vec(torch::Tensor input){
    const torch::Tensor& output = torch::zeros_like(input);
    const int batch_stride = input.strides()[0]; 
    const int half_batch_stride = batch_stride / 2;
    float *__restrict__ in_i1 = input.data_ptr<float>();
    float *__restrict__ in_i2 = in_i1 + half_batch_stride;
    float *__restrict__ out_i1 = output.data_ptr<float>();
    float *__restrict__ out_i2 = out_i1 + half_batch_stride;
    const float* end = input.data_ptr<float>() + input.numel();
      
    // This loop goes over samples in the batch
    while(in_i2 < end){ 
        // This loop goes over the element pairs in a sample
        for (int i=0; i<half_batch_stride; i++){ 
            float v1 = in_i1[i];
            float v2 = in_i2[i];
            bool b1 = std::signbit(v1), b2 = std::signbit(v2);
            // out1 = max(v1,0) + min(v2,0); out2 = max(v2,0) + min(v1,0)
            out_i1[i] = b1 ? (b2 ? v2 : 0.0f) : (b2 ? v1+v2 : v1);
            out_i2[i] = b1 ? (b2 ? v1 : v1+v2) : (b2 ? 0.0f : v2);      
        }
        // Advance all pointers by one full row.
        in_i1 += batch_stride;
        in_i2 += batch_stride;
        out_i1 += batch_stride;
        out_i2 += batch_stride;
    }    
    return output;
}

It’s no big deal, though. I mainly care about the CUDA version, which turned out to work great.

Yes, your changes break vectorization - as I said, it is very brittle. If you view godbolt.org, my code has characteristic AVX instructions (the ‘ps’ suffix, for packed single) and an “add rsi, 32” line, which is a step over 8 floats. The big chunk of code below that handles the 0-7 tail elements.

A more practical way is to enable compiler diagnostics, which may tell you why auto-vectorization failed (though the message may be obscure).
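
For example (flags I would try; the file name is made up and the exact messages depend on the compiler version):

g++     -O3 -march=native -fopt-info-vec-missed        -c symsum.cpp
clang++ -O3 -march=native -Rpass-missed=loop-vectorize -c symsum.cpp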


PyTorch has internal vec256 headers which abstract AVX and might be a good resource.
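
For illustration, a per-row inner loop written against those headers might look roughly like this (an untested sketch; it assumes the Vec256<float> API with loadu/store, a broadcasting constructor, and the maximum/minimum helpers, and the helper name is made up):

#include <ATen/cpu/vec256/vec256.h>
#include <algorithm>
#include <cstdint>

// in1/in2 point to the two halves of one input row, out1/out2 to the output halves.
void symsum_row_vec256(const float* in1, const float* in2,
                       float* out1, float* out2, int64_t half){
    using Vec = at::vec256::Vec256<float>;
    const Vec zero(0.0f);
    int64_t i = 0;
    for (; i + 8 <= half; i += 8){  // 8 floats per 256-bit register
        const Vec v1 = Vec::loadu(in1 + i);
        const Vec v2 = Vec::loadu(in2 + i);
        // out1 = max(v1, 0) + min(v2, 0); out2 = max(v2, 0) + min(v1, 0)
        (at::vec256::maximum(v1, zero) + at::vec256::minimum(v2, zero)).store(out1 + i);
        (at::vec256::maximum(v2, zero) + at::vec256::minimum(v1, zero)).store(out2 + i);
    }
    for (; i < half; ++i){  // scalar tail
        out1[i] = std::max(in1[i], 0.0f) + std::min(in2[i], 0.0f);
        out2[i] = std::max(in2[i], 0.0f) + std::min(in1[i], 0.0f);
    }
}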

Best regards

Thomas
