Custom MaxPool autograd and C++ extension

Hello,

I would like to customize a MaxPooling function using a C++ extension. Specifically, I would like to retain the last index when the maximum value is achieved, whereas the native implementation’s MaxPool backward function retains the first index. I am working on CPU, and I have provided my code below. I believe my code is running rather slowly. Do you have any suggestions for simple ways to improve CPU speed?

#include <torch/extension.h>
#include <ATen/Parallel.h>
#include <c10/util/Exception.h>

torch::Tensor max_pool_forward(torch::Tensor input, int kernel_size, int stride) {
    TORCH_CHECK(input.dim() == 4, "Expected 4D tensor, but got ", input.dim(), "D tensor instead.");
    TORCH_CHECK(kernel_size > 0, "kernel size must be greater than zero");
    TORCH_CHECK(stride > 0, "stride must be greater than zero");
    TORCH_CHECK(input.size(2) >= kernel_size && input.size(3) >= kernel_size,
                "input tensor spatial dimensions must be at least kernel size");

    auto b = input.size(0);
    auto n_C_prev = input.size(1);
    auto n_H_prev = input.size(2);
    auto n_W_prev = input.size(3);

    auto n_H = static_cast<int>((n_H_prev - kernel_size) / stride + 1);
    auto n_W = static_cast<int>((n_W_prev - kernel_size) / stride + 1);
    auto n_C = n_C_prev;

    auto output = torch::zeros({b, n_C, n_H, n_W}, input.options());


    at::parallel_for(0, b * n_C * n_H * n_W, input.numel(), [&](int64_t start, int64_t end) {
        for (int64_t i = start; i < end; ++i) {
            int64_t b_idx = i / (n_C * n_H * n_W);
            int64_t c_idx = (i / (n_H * n_W)) % n_C;
            int64_t hw_idx = i % (n_H * n_W);

            int64_t h_idx = hw_idx / n_W;
            int64_t w_idx = hw_idx % n_W;

            auto vert_start = h_idx * stride;
            auto vert_end = vert_start + kernel_size;
            auto horiz_start = w_idx * stride;
            auto horiz_end = horiz_start + kernel_size;
            auto a_prev_slice = input[b_idx][c_idx].slice(0, vert_start, vert_end).slice(1, horiz_start, horiz_end);
            output[b_idx][c_idx][h_idx][w_idx] = torch::max(a_prev_slice);
        }
    });

    return output;
}


torch::Tensor max_pool_backward(torch::Tensor input, torch::Tensor grad_output, int kernel_size, int stride) {
    auto b = input.size(0);
    auto n_C_prev = input.size(1);
    auto n_H_prev = input.size(2);
    auto n_W_prev = input.size(3);

    auto n_C = grad_output.size(1);
    auto n_H = grad_output.size(2);
    auto n_W = grad_output.size(3);

    auto grad_x = torch::zeros({b, n_C_prev, n_H_prev, n_W_prev}, input.options());
    auto indices = torch::zeros_like(grad_x);

    at::parallel_for(0, b * n_C * n_H * n_W, input.numel(), [&](int64_t start, int64_t end) {
        for (int64_t i = start; i < end; ++i) {
            int64_t b_idx = i / (n_C * n_H * n_W);
            int64_t c_idx = (i / (n_H * n_W)) % n_C;
            int64_t hw_idx = i % (n_H * n_W);

            int64_t h_idx = hw_idx / n_W;
            int64_t w_idx = hw_idx % n_W;

            auto a_prev = input[b_idx];
            auto vert_start = h_idx * stride;
            auto vert_end = vert_start + kernel_size;
            auto horiz_start = w_idx * stride;
            auto horiz_end = horiz_start + kernel_size;
            auto a_prev_slice = a_prev[c_idx].slice(0, vert_start, vert_end).slice(1, horiz_start, horiz_end);

            auto max_vals = torch::max(a_prev_slice);
            auto mask = torch::eq(a_prev_slice, max_vals).to(input.dtype().toScalarType());
            auto mask_1d = mask.view({-1});

            auto mask_2 = torch::zeros_like(mask_1d);
            auto nonzero_indices = torch::nonzero(mask_1d);
            auto last_nonzero_index = nonzero_indices[nonzero_indices.size(0) - 1][0].item<int64_t>();
            mask_2[last_nonzero_index] = 1.0;

            auto new_mask = torch::reshape(mask_2, {kernel_size, kernel_size});

            auto grad_val = grad_output[b_idx][c_idx][h_idx][w_idx];
            auto grad_slice = grad_val * new_mask;

            grad_x[b_idx][c_idx].slice(0, vert_start, vert_end).slice(1, horiz_start, horiz_end) += grad_slice;
        }
    });

    return grad_x;
}

I would take a look at how your implementation differs from the strategy in the native kernel, if you haven't done so already.
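The main things to compare: the native kernel passes a small grain size to at::parallel_for (your call passes input.numel() as the grain size, which is larger than the whole iteration range, so the loop runs as a single chunk), and inside the loop it scans raw data pointers while tracking a running max and its flat index, instead of building a tensor slice and calling torch::max (and, in backward, eq/nonzero/reshape) per output element, each of which allocates temporaries. A rough sketch of that inner scan, with illustrative names and float hard-coded just for brevity:

#include <cmath>
#include <cstdint>
#include <limits>

// Native-style window scan: raw pointers, a running max and its flat argmax,
// no temporary tensors inside the loop.
float window_max(const float* channel_ptr, int64_t input_width,
                 int64_t ih0, int64_t ih1, int64_t iw0, int64_t iw1,
                 int64_t& maxindex) {
  maxindex = ih0 * input_width + iw0;
  float maxval = -std::numeric_limits<float>::infinity();
  for (int64_t ih = ih0; ih < ih1; ++ih) {
    for (int64_t iw = iw0; iw < iw1; ++iw) {
      const int64_t index = ih * input_width + iw;
      const float val = channel_ptr[index];
      // The native kernel uses '>' and therefore keeps the first maximum;
      // switching to '>=' keeps the last one, which is what you want.
      if (val > maxval || std::isnan(val)) {
        maxval = val;
        maxindex = index;
      }
    }
  }
  return maxval;
}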

UPDATE:

I tried starting from the native MaxPool kernel. I changed if (val > maxval || std::isnan(val)) to if (val >= maxval || std::isnan(val)) so that the last index is kept instead of the first.

#include <ATen/ATen.h>
#include <ATen/native/AdaptivePooling.h>
#include <ATen/native/Pool.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
#include <torch/extension.h>


std::tuple<at::Tensor, at::Tensor> max_pool(
    const at::Tensor& input_,
    const int64_t kernel_size,
    const int64_t stride) {

  auto input = input_.contiguous();
  const int64_t ndim = input.ndimension();
  const int64_t channels = ndim == 3 ? input.size(0) : input.size(1) * input.size(2);
  const int64_t input_height = input.size(-2);
  const int64_t input_width = input.size(-1);
  const int64_t output_height = at::native::pooling_output_shape<int64_t>(
      input_height, kernel_size, 0, 1, stride, false);
  const int64_t output_width = at::native::pooling_output_shape<int64_t>(
      input_width, kernel_size, 0, 1, stride, false);

  auto output = at::empty({channels, output_height, output_width}, input.options());
  auto indices = at::empty({channels, output_height, output_width}, input.options().dtype(at::kLong));

  auto input_data = input.data_ptr<int64_t>();
  auto output_data = output.data_ptr<int64_t>();
  auto indices_data = indices.data_ptr<int64_t>();

  const int64_t numel = channels * output_height * output_width;
  at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) {
    int64_t c = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    at::native::data_index_init(begin, c, channels, oh, output_height, ow, output_width);
    for (const auto i : c10::irange(begin, end)) {
      int64_t ih0 = oh * stride;
      int64_t iw0 = ow * stride;
      int64_t ih1 = std::min(ih0 + kernel_size, input_height);
      int64_t iw1 = std::min(iw0 + kernel_size, input_width);

      // local pointers
      const int64_t* input_ptr = input_data + c * input_height * input_width;

      // compute local max
      int64_t maxindex = ih0 * input_width + iw0;
      int64_t maxval = -std::numeric_limits<int64_t>::infinity();
      for (int64_t ih = ih0; ih < ih1; ++ih) {
        for (int64_t iw = iw0; iw < iw1; ++iw) {
          int64_t index = ih * input_width + iw;
          int64_t val = static_cast<int64_t>(input_ptr[index]);
          if (val >= maxval || std::isnan(val)) {
            maxval = val;
            maxindex = index;
          }
        }
      }

      // set output to local max and store location of max
      output_data[i] = static_cast<int64_t>(maxval);
      indices_data[i] = maxindex;

      // move on to next output index
      at::native::data_index_step(c, channels, oh, output_height, ow, output_width);
    }
  });

  return std::make_tuple(output, indices);
}



at::Tensor max_pool_backward(
    const at::Tensor& grad_output_,
    const at::Tensor& indices_) {
      
  auto grad_output = grad_output_.contiguous();
  auto indices = indices_.contiguous();
  auto grad_input = at::zeros_like(indices, grad_output.options());

  auto grad_output_data = grad_output.data_ptr<int64_t>();
  auto indices_data = indices.data_ptr<int64_t>();
  auto grad_input_data = grad_input.data_ptr<int64_t>();

  int64_t ndim = grad_output.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
  int64_t input_height = grad_input.size(-2);
  int64_t input_width = grad_input.size(-1);
  int64_t output_height = grad_output.size(-2);
  int64_t output_width = grad_output.size(-1);

  // parallel on dim of N, C
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (const auto c : c10::irange(begin, end)) {
      int64_t* grad_input_ptr = grad_input_data + c * input_height * input_width;
      int64_t* grad_output_ptr = grad_output_data + c * output_height * output_width;
      int64_t * indices_ptr = indices_data + c * output_height * output_width;

      for (const auto oh : c10::irange(output_height)) {
        for (const auto ow : c10::irange(output_width)) {
          // retrieve position of max
          int64_t index = oh * output_width + ow;
          int64_t maxindex = indices_ptr[index];
          if (maxindex != -1) {
            // update gradient
            grad_input_ptr[maxindex] += grad_output_ptr[index];
          }
        }
      }
    }
  });
  
  return grad_input;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &max_pool, "CPU MaxPool operation");
  m.def("backward", &max_pool_backward, "CPU MaxPool backward operation");
}

When I test my MaxPool C++ extension, I run into a type problem. Can someone help me, please?


class lastindex_pool2d(autograd.Function):
    @staticmethod
    def forward(ctx, input, kernel_size, stride):
        # Call the forward function in the C++ extension

        ctx.kernel_size = kernel_size
        ctx.stride = stride 
        print(type(stride))
        output, indices = lastindex_pool.forward(input, ctx.kernel_size, ctx.stride)

        # Save the input and output tensors for backward
        ctx.save_for_backward(input, output)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Load the saved input and kernel size/stride
        input, output = ctx.saved_tensors
      
        # Call the backward function in the C++ extension
        grad_input = lastindex_pool.backward(input, grad_output,  ctx.kernel_size, ctx.stride)

        return grad_input, None, None


# Test the MaxPool2D function
input = torch.randn(1,1,4,4)
print(input)
input[0][0][0][0] = 2
input[0][0][0][1] = 2
input[0][0][1][0] = 2
input= input.requires_grad_()

print(input)
output = lastindex_pool2d.apply(input, 2, 2)
print(output)
output.sum().backward()
print(input.grad)
RuntimeError                              Traceback (most recent call last)
/var/folders/5t/vn7jb6rx2874dqk3khxl4j540000gn/T/ipykernel_32033/1040729934.py in <module>
     34 
     35 print(input)
---> 36 output = lastindex_pool2d.apply(input, 2, 2)
     37 #print(output)
     38 #output.sum().backward()

/var/folders/5t/vn7jb6rx2874dqk3khxl4j540000gn/T/ipykernel_32033/1040729934.py in forward(ctx, input, kernel_size, stride)
      7         ctx.stride = stride
      8         print(type(stride))
----> 9         output, indices = lastindex_pool.forward(input, ctx.kernel_size, ctx.stride)
     10 
     11         # Save the input and kernel size/stride for backward

RuntimeError: expected scalar type Int but found Float

I don’t know where the problem occurs.

It looks like you are casting your input to int64_t in a lot of places where the original code would use scalar_t or accscalar_t. This may work for initial testing but would also mean that your code would only support integral types.

If you just want to fix the current error, I would try passing dtype=torch.long to torch.randn(1, 1, 4, 4), but keep in mind that if you want to use other data types you would have to modify your extension to generalize the input dtype support the way the original native implementation does.
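To make the dispatch part concrete, here is a minimal sketch of the pattern the native implementation uses: the kernel body is a template on scalar_t and reads the data through data_ptr<scalar_t>(), and a small wrapper picks the instantiation from the runtime dtype. The names (global_max_kernel, global_max) are illustrative, and the body just computes a global max to keep the sketch short; a real version would contain your pooling loops.

#include <torch/extension.h>
#include <tuple>

// Illustrative templated kernel body: note data_ptr<scalar_t>() rather than
// data_ptr<int64_t>(). Keeps the last occurrence of the maximum.
template <typename scalar_t>
std::tuple<torch::Tensor, torch::Tensor> global_max_kernel(const torch::Tensor& input) {
  auto in = input.contiguous();
  const scalar_t* data = in.data_ptr<scalar_t>();
  scalar_t maxval = data[0];
  int64_t maxindex = 0;
  for (int64_t i = 0; i < in.numel(); ++i) {
    if (data[i] >= maxval) {  // '>=' keeps the last index
      maxval = data[i];
      maxindex = i;
    }
  }
  auto out = torch::empty({1}, in.options());
  out.data_ptr<scalar_t>()[0] = maxval;
  auto idx = torch::empty({1}, in.options().dtype(torch::kLong));
  idx.data_ptr<int64_t>()[0] = maxindex;
  return std::make_tuple(out, idx);
}

// The wrapper selects the template instantiation from the runtime dtype;
// scalar_t is defined inside the lambda by the dispatch macro.
std::tuple<torch::Tensor, torch::Tensor> global_max(const torch::Tensor& input) {
  TORCH_CHECK(input.numel() > 0, "expected a non-empty tensor");
  return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "global_max", [&] {
    return global_max_kernel<scalar_t>(input);
  });
}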

Thanks. I tried it, but I got this error: “normal_kernel_cpu” not implemented for ‘Long’.

In that case you can use a different initialization function or forcibly coerce it via .to(dtype=torch.long) rather than passing it in the call.

Here is an update with corrected code for the previous problem:

template <typename scalar_t, typename accscalar_t>
std::tuple<torch::Tensor, torch::Tensor> cpu_max_pool_template(
    const torch::Tensor& input_,
    int kernel_size, int stride) {
  auto input = input_.contiguous();

  int64_t kH = kernel_size;
  int64_t kW = kernel_size;
  int64_t dH = stride;
  int64_t dW = stride;

  auto b = input.size(0);
  auto n_C_prev = input.size(1);
  auto n_H_prev = input.size(2);
  auto n_W_prev = input.size(3);

  auto n_H = static_cast<int>((n_H_prev - kernel_size) / stride + 1);
  auto n_W = static_cast<int>((n_W_prev - kernel_size) / stride + 1);
  auto n_C = n_C_prev;

  auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device()).requires_grad(false);
  auto output = torch::empty({b, n_C, n_H, n_W}, options);
  auto indices = torch::empty({b, n_C, n_H, n_W}, torch::kLong);

  auto input_data = input.data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();
  auto indices_data = indices.data_ptr<int64_t>();

  int64_t numel = output.numel();
  int64_t ndim = input.ndimension();
  // treat batch size and channels as one dimension
  int64_t channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1);
  int64_t input_height = input.size(-2);
  int64_t input_width = input.size(-1);
  int64_t output_height = output.size(-2);
  int64_t output_width = output.size(-1);

  // parallel on dim N, C, H, W
  at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) {
    int64_t c = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    at::native::data_index_init(begin, c, channels, oh, output_height, ow, output_width);

    for (const auto i : c10::irange(begin, end)) {
      int64_t ih0 = oh * dH;
      int64_t iw0 = ow * dW;
      int64_t ih1 = std::min(ih0 + kH, input_height);
      int64_t iw1 = std::min(iw0 + kW, input_width);

      // local pointers
      scalar_t* input_ptr = input_data + c * input_height * input_width;

      // compute local max
      int64_t maxindex = ih0 * input_width + iw0;
      accscalar_t maxval = input_ptr[maxindex];
      for (int64_t ih = ih0; ih < ih1; ++ih) {
        for (int64_t iw = iw0; iw < iw1; ++iw) {
          int64_t index = ih * input_width + iw;
          accscalar_t val = input_ptr[index];
          if (val >= maxval || std::isnan(val)) {
            maxval = val;
            maxindex = index;
          }
        }
      }

      // set output to local max and store location of max
      output_data[i] = scalar_t(maxval);
      indices_data[i] = maxindex;

      // move on to next output index
      at::native::data_index_step(c, channels, oh, output_height, ow, output_width);
    }
  });

  return std::make_tuple(output, indices);
}

// Define a function that calls cpu_max_pool_template with the appropriate scalar types
std::tuple<torch::Tensor, torch::Tensor> cpu_max_pool(
    const torch::Tensor& tensor, int kernel_size, int stride) {
  TORCH_CHECK(tensor.dim() == 4, "Only 4D input Tensors are supported (nbatch, channels, height, width)");

  // Call cpu_max_pool_template with the appropriate scalar types
  return AT_DISPATCH_FLOATING_TYPES(tensor.scalar_type(), "cpu_max_pool", [&] {
    return cpu_max_pool_template<scalar_t, scalar_t>(tensor, kernel_size, stride);
  });
}
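
One note on the accscalar_t template parameter: the dispatch above instantiates it as scalar_t, which is fine for float and double. If you later want to support reduced-precision inputs as well, one option is ATen's at::acc_type helper, which on CPU maps e.g. Half and BFloat16 to float; a sketch assuming the same cpu_max_pool_template (cpu_max_pool_acc is just an illustrative name):

#include <ATen/AccumulateType.h>

// Sketch only: extend the dispatch to Half/BFloat16 and accumulate in the wider
// type given by at::acc_type (the bool template argument selects the CPU rules).
std::tuple<torch::Tensor, torch::Tensor> cpu_max_pool_acc(
    const torch::Tensor& tensor, int kernel_size, int stride) {
  TORCH_CHECK(tensor.dim() == 4, "Only 4D input Tensors are supported (nbatch, channels, height, width)");
  return AT_DISPATCH_FLOATING_TYPES_AND2(
      at::kHalf, at::kBFloat16, tensor.scalar_type(), "cpu_max_pool_acc", [&] {
    using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/false>;
    return cpu_max_pool_template<scalar_t, accscalar_t>(tensor, kernel_size, stride);
  });
}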