Hello,
I would like to customize a MaxPooling function using a C++ extension. Specifically, I would like to retain the last index when the maximum value is achieved, whereas the native implementation’s MaxPool backward function retains the first index. I am working on CPU, and I have provided my code below. I believe my code is running rather slowly. Do you have any suggestions for simple ways to improve CPU speed?
#include <torch/extension.h>
#include <ATen/Parallel.h>
#include <c10/util/Exception.h>
torch::Tensor max_pool_forward(torch::Tensor input, int kernel_size, int stride) {
TORCH_CHECK(input.dim() == 4, "Expected 4D tensor, but got ", input.dim(), "D tensor instead.");
TORCH_CHECK(kernel_size > 0, "kernel size must be greater than zero");
TORCH_CHECK(stride > 0, "stride must be greater than zero");
TORCH_CHECK(input.size(2) >= kernel_size && input.size(3) >= kernel_size,
"input tensor spatial dimensions must be at least kernel size");
auto b = input.size(0);
auto n_C_prev = input.size(1);
auto n_H_prev = input.size(2);
auto n_W_prev = input.size(3);
auto n_H = static_cast<int>((n_H_prev - kernel_size) / stride + 1);
auto n_W = static_cast<int>((n_W_prev - kernel_size) / stride + 1);
auto n_C = n_C_prev;
auto output = torch::zeros({b, n_C, n_H, n_W}, input.options());
at::parallel_for(0, b * n_C * n_H * n_W, input.numel(), [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; ++i) {
int64_t b_idx = i / (n_C * n_H * n_W);
int64_t c_idx = (i / (n_H * n_W)) % n_C;
int64_t hw_idx = i % (n_H * n_W);
int64_t h_idx = hw_idx / n_W;
int64_t w_idx = hw_idx % n_W;
auto vert_start = h_idx * stride;
auto vert_end = vert_start + kernel_size;
auto horiz_start = w_idx * stride;
auto horiz_end = horiz_start + kernel_size;
auto a_prev_slice = input[b_idx][c_idx].slice(0, vert_start, vert_end).slice(1, horiz_start, horiz_end);
output[b_idx][c_idx][h_idx][w_idx] = torch::max(a_prev_slice);
}
});
return output;
}
torch::Tensor max_pool_backward(torch::Tensor input, torch::Tensor grad_output, int kernel_size, int stride) {
auto b = input.size(0);
auto n_C_prev = input.size(1);
auto n_H_prev = input.size(2);
auto n_W_prev = input.size(3);
auto n_C = grad_output.size(1);
auto n_H = grad_output.size(2);
auto n_W = grad_output.size(3);
auto grad_x = torch::zeros({b, n_C_prev, n_H_prev, n_W_prev}, input.options());
auto indices = torch::zeros_like(grad_x);
at::parallel_for(0, b * n_C * n_H * n_W, input.numel(), [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; ++i) {
int64_t b_idx = i / (n_C * n_H * n_W);
int64_t c_idx = (i / (n_H * n_W)) % n_C;
int64_t hw_idx = i % (n_H * n_W);
int64_t h_idx = hw_idx / n_W;
int64_t w_idx = hw_idx % n_W;
auto a_prev = input[b_idx];
auto vert_start = h_idx * stride;
auto vert_end = vert_start + kernel_size;
auto horiz_start = w_idx * stride;
auto horiz_end = horiz_start + kernel_size;
auto a_prev_slice = a_prev[c_idx].slice(0, vert_start, vert_end).slice(1, horiz_start, horiz_end);
auto max_vals = torch::max(a_prev_slice);
auto mask = torch::eq(a_prev_slice, max_vals).to(input.dtype().toScalarType());
auto mask_1d = mask.view({-1});
auto mask_2 = torch::zeros_like(mask_1d);
auto nonzero_indices = torch::nonzero(mask_1d);
auto last_nonzero_index = nonzero_indices[nonzero_indices.size(0) - 1][0].item<int64_t>();
mask_2[last_nonzero_index] = 1.0;
auto new_mask = torch::reshape(mask_2, {2, 2});
auto grad_val = grad_output[b_idx][c_idx][h_idx][w_idx];
auto grad_slice = grad_val * new_mask;
grad_x[b_idx][c_idx].slice(0, vert_start, vert_end).slice(1, horiz_start, horiz_end) += grad_slice;
}
});
return grad_x;
}