I’m writing a C++ extension to optimize the performance of a custom function. As a starting point, I’m focusing only on the CPU version.

My optimized version should in principle be ~30% faster.

However, it runs in about the same time. Below is my code:

**Unoptimized version:**

```
torch::Tensor symsum_forward_unoptimized(torch::Tensor input) {
    const int64_t shift = input.size(1) / 2;
    // A 0-dim zero tensor to compare against.
    const torch::Tensor zero = torch::zeros({}, input.options());
    auto relu = torch::maximum(input, zero);
    auto inv_relu = torch::minimum(input, zero);
    auto output = torch::roll(inv_relu, shift, 1) + relu;
    return output;
}
```

**Optimized version:**

```
torch::Tensor symsum_forward_ptr(torch::Tensor input) {
    // Create and initialize the output tensor.
    torch::Tensor output = torch::zeros_like(input);
    // The batch stride is needed for the pointer arithmetic.
    // (Assumes `input` is a contiguous 2-D float32 tensor.)
    const int64_t batch_stride = input.strides()[0];
    const int64_t half_batch_stride = batch_stride / 2;
    // Pointers into the two halves of `input` and `output`.
    float* in_i1 = input.data_ptr<float>();    // input, first half
    float* in_i2 = in_i1 + half_batch_stride;  // input, second half
    float* out_i1 = output.data_ptr<float>();  // output, first half
    float* out_i2 = out_i1 + half_batch_stride;  // output, second half
    // End-of-batch and end-of-tensor pointers for `input`.
    float* batch_end = in_i2 + half_batch_stride;
    const float* end = input.data_ptr<float>() + input.numel();
    // Outer loop: samples in the batch.
    while (in_i2 < end) {
        // Inner loop: elements in a sample.
        while (in_i2 < batch_end) {
            // `std::signbit` is used instead of `(*in_i1) < 0`; it is faster.
            if (std::signbit(*in_i1)) {
                if (std::signbit(*in_i2)) {
                    *out_i1 = *in_i2;
                    *out_i2 = *in_i1;
                } else {
                    *out_i2 = (*in_i2) + (*in_i1);
                }
            } else {
                if (std::signbit(*in_i2)) {
                    *out_i1 = (*in_i2) + (*in_i1);
                } else {
                    *out_i1 = *in_i1;
                    *out_i2 = *in_i2;
                }
            }
            in_i1++;
            in_i2++;
            out_i1++;
            out_i2++;
        }
        in_i1 += half_batch_stride;
        in_i2 += half_batch_stride;
        out_i1 += half_batch_stride;
        out_i2 += half_batch_stride;
        batch_end += batch_stride;
    }
    return output;
}
```

The idea behind the optimized version is simple. In the unoptimized version, each element of `input` is compared with zero twice (once in `maximum` and once in `minimum`). However, this only needs to happen once: if an element is greater than or equal to zero, then I already know it is not smaller than zero.

The optimized version takes advantage of this insight. It performs the same function, but with only one comparison per element. Furthermore, it allocates one third of the memory of the unoptimized one. So in principle it should be faster.

Any insight into why it is not faster would be much appreciated.