PyTorch C++ modification of tensors

I am working with the AlexNet from GitHub - bhiziroglu/Image-Classification-with-Deep-Convolutional-Neural-Networks (a C++ / LibTorch implementation of AlexNet) and want to replace Conv2d with a custom convolution that modifies the input tensor before applying Conv2d. To this end, I created the following module:

struct MyConv : torch::nn::Module {
    MyConv() {}
    MyConv(int in_ch, int out_ch, int kernel, int stride, int pad)
        : real_conv(torch::nn::Conv2dOptions(in_ch, out_ch, kernel)
                        .stride(stride)
                        .padding(pad)
                        .bias(true))
    {
        register_module("custom_conv", real_conv);
    }

    torch::Tensor forward(torch::Tensor x)
    {
        const int64_t N = 8;  // group size along the width dimension
        const int64_t K = 1;  // quantization uses 1 << K steps per group
        torch::Tensor res = x.detach().clone();
        const int64_t col_num = res.size(3);
        const int64_t col_pad = (col_num + N - 1) / N * N;  // width rounded up to a multiple of N
        torch::Tensor augmented = torch::zeros({res.size(0), res.size(1), res.size(2), col_pad},
                                               torch::Device(torch::kCUDA, 0));
        // Quantize every (batch, channel) slice independently.
        for (int64_t i = 0; i < res.size(0); ++i)
        {
            for (int64_t j = 0; j < res.size(1); ++j)
            {
                // Replicate-pad the slice so its width is divisible by N.
                torch::Tensor padded = torch::nn::functional::pad(res.index({i, j, "..."}),
                        torch::nn::functional::PadFuncOptions({0, col_pad - col_num}).mode(torch::kReplicate));
                augmented.index_put_({i, j, "..."}, padded);
                // View the slice as groups of N consecutive pixels.
                torch::Tensor tnsr_r = torch::reshape(augmented.index({i, j, "..."}),
                        {res.size(2) * col_pad / N, N});
                // Per-group maximum, broadcast back to the slice shape.
                torch::Tensor max_v = std::get<0>(torch::max(tnsr_r, 1));
                max_v = torch::unsqueeze(max_v, 1).expand({tnsr_r.size(0), N});
                max_v = torch::reshape(max_v, {res.size(2), col_pad});
                // Per-group minimum, broadcast back to the slice shape.
                torch::Tensor min_v = std::get<0>(torch::min(tnsr_r, 1));
                min_v = torch::unsqueeze(min_v, 1).expand({tnsr_r.size(0), N});
                min_v = torch::reshape(min_v, {res.size(2), col_pad});
                // Quantization step per group; nan_to_num guards against zero ranges.
                torch::Tensor steps = torch::div(max_v - min_v, 1 << K);
                torch::Tensor steps_inv = torch::nan_to_num(steps.pow(-1), 0, 0, 0);
                // Snap every pixel to its nearest quantization level.
                torch::Tensor inds = torch::round(torch::mul(augmented.index({i, j, "..."}) - min_v, steps_inv));
                augmented.index_put_({i, j, "..."}, min_v + torch::mul(inds, steps));
            }
        }
        // Copy the quantized values (without the padding columns) back and convolve.
        res.index_put_({torch::indexing::Slice(), torch::indexing::Slice(),
                        torch::indexing::Slice(), torch::indexing::Slice()},
                       augmented.index({torch::indexing::Slice(), torch::indexing::Slice(),
                                        torch::indexing::Slice(), torch::indexing::Slice(0, col_num, 1)}));
        return real_conv(res);
    }

    torch::nn::Conv2d real_conv{nullptr};
};

The usage in AlexNet is as follows:

class AlexNet : public torch::nn::Module {
    // Modified AlexNet for Cifar dataset
    public:
        explicit AlexNet(int64_t num_classes=100);
        Tensor forward(Tensor x);

    private:
        torch::nn::Sequential features{
            //torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 64, 3).stride(2).padding(1)),
            MyConv(3, 64, 3, 2, 1),
            torch::nn::ReLU(torch::nn::ReLUOptions(true)),
            torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions({2, 2})),

            //torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 192, 3).padding(1)),
            MyConv(64, 192, 3, 1, 1),
            torch::nn::ReLU(torch::nn::ReLUOptions(true)),
            torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions({2, 2})),

            //torch::nn::Conv2d(torch::nn::Conv2dOptions(192, 384, 3).padding(1)),
            MyConv(192, 384, 3, 1, 1),
            torch::nn::ReLU(torch::nn::ReLUOptions(true)),

            //torch::nn::Conv2d(torch::nn::Conv2dOptions(384, 256, 3).padding(1)),
            MyConv(384, 256, 3, 1, 1),
            torch::nn::ReLU(torch::nn::ReLUOptions(true)),

            //torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1)),
            MyConv(256, 256, 3, 1, 1),
            torch::nn::ReLU(torch::nn::ReLUOptions(true)),

            torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions({2, 2}))
        };

        torch::nn::Sequential classifier{
            torch::nn::Dropout(torch::nn::DropoutOptions().p(0.5)),
            torch::nn::Linear(torch::nn::LinearOptions(256 * 2 * 2, 4096)),
            torch::nn::ReLU(torch::nn::ReLUOptions(true)),
            torch::nn::Dropout(torch::nn::DropoutOptions().p(0.5)),
            torch::nn::Linear(torch::nn::LinearOptions(4096, 4096)),
            torch::nn::ReLU(torch::nn::ReLUOptions(true))
        };

        torch::nn::Linear fc;
};

This change dramatically reduces GPU utilization, and I'm wondering how to improve the performance in this case.

If I understand your code correctly, you are replacing the conv layer (which uses e.g. a cuDNN kernel) with a manual implementation that uses a nested for loop to iterate over every (batch, channel) slice, launching ~17 operations (and thus kernels) per iteration.
You might need to rewrite your custom operation as a custom C++/CUDA extension if possible. Alternatively, you could try CUDA Graphs to at least reduce the dispatching and kernel-launch overhead of your implementation. Getting rid of the nested for loop and applying vectorized operations to the whole input at once should also give you a speedup.
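
To illustrate the last suggestion, here is a minimal sketch (not your exact solution, just an illustration) of a forward that keeps the same N/K quantization scheme but removes the per-(batch, channel) loop, so only a handful of kernels are launched regardless of batch size. It assumes the same real_conv member as in MyConv above:

// A sketch of MyConv::forward without the per-(batch, channel) loop
// (assumes the same N/K scheme and the same real_conv member as above).
torch::Tensor forward(torch::Tensor x)
{
    const int64_t N = 8;
    const int64_t K = 1;
    const int64_t W = x.size(3);
    const int64_t W_pad = (W + N - 1) / N * N;

    // Mirror the original's detach, then replicate-pad the whole batch at once
    // so the width is divisible by N.
    torch::Tensor inp = x.detach();
    torch::Tensor padded = torch::nn::functional::pad(inp,
            torch::nn::functional::PadFuncOptions({0, W_pad - W}).mode(torch::kReplicate));

    // View every row as groups of N consecutive pixels: (B, C, H, W_pad / N, N).
    torch::Tensor groups = padded.reshape({x.size(0), x.size(1), x.size(2), W_pad / N, N});

    // Per-group min/max and quantization step, computed in a few kernels total.
    torch::Tensor max_v = std::get<0>(groups.max(-1, /*keepdim=*/true));
    torch::Tensor min_v = std::get<0>(groups.min(-1, /*keepdim=*/true));
    torch::Tensor steps = (max_v - min_v) / (1 << K);
    torch::Tensor steps_inv = torch::nan_to_num(steps.pow(-1), 0, 0, 0);

    // Snap every pixel to its nearest level, drop the padding columns, convolve.
    torch::Tensor quant = min_v + torch::round((groups - min_v) * steps_inv) * steps;
    torch::Tensor res = quant.reshape({x.size(0), x.size(1), x.size(2), W_pad})
                             .index({torch::indexing::Slice(), torch::indexing::Slice(),
                                     torch::indexing::Slice(), torch::indexing::Slice(0, W)});
    return real_conv(res);
}

Compared to the loop version, this launches the same kinds of operations but only once for the full tensor, which is usually enough to keep the GPU busy; a fused C++/CUDA extension would reduce the overhead further.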

Thank you, I managed to solve the problem by avoiding nested loops.