Preprocess input for pretrained networks using batchnorm

Hi all,
I am trying to preprocess input images inside the model using batch_norm, instead of normalizing the input tensor with mean/std values beforehand. The following Python code works as intended:

import torch

class PreProcess(torch.nn.Module):
    def __init__(self, mean, var):
        super(PreProcess, self).__init__()
        # Buffers (unlike plain attributes) follow the module through .to(device) / .cuda()
        self.register_buffer("mean", torch.tensor(mean, dtype=torch.float32))
        self.register_buffer("var", torch.tensor(var, dtype=torch.float32))

    def forward(self, x):
        # Input arrives as NHWC; batch_norm normalizes over dim 1, so move channels there (NCHW)
        x = x.permute(0, 3, 1, 2)

        # With training=False and no weight/bias, batch_norm computes
        # (x - mean) / sqrt(var + eps) per channel; momentum is ignored in this mode.
        x = torch.nn.functional.batch_norm(x,
                                           self.mean,
                                           self.var,
                                           weight=None,
                                           bias=None,
                                           training=False,
                                           momentum=0.1,
                                           eps=1e-05)

        return x

def load_model(model_path, mean, std):
    model = torch.jit.load(model_path)
    var = [s * s for s in std]  # batch_norm expects variance, i.e. std squared
    ppmodel = torch.nn.Sequential(PreProcess(mean, var), model)
    ppmodel.eval()
    return ppmodel
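For reference, with training=False and no weight or bias, batch_norm maps each channel c to y = (x - mean[c]) / sqrt(var[c] + eps), so passing var = std² reproduces the usual (x - mean) / std normalization up to the small eps. The dependency-free C++ sketch below spells out that arithmetic together with the NHWC to NCHW permute, as a reference to check a libtorch port against; preprocess_reference is a hypothetical helper name, not a libtorch function.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference helper (not part of libtorch): what the PreProcess
// module computes. Reads an NHWC buffer, writes an NCHW buffer, and applies
// eval-mode batch_norm with running_mean = mean, running_var = var and no
// weight/bias: y = (x - mean[c]) / sqrt(var[c] + eps).
std::vector<float> preprocess_reference(const std::vector<float>& nhwc,
                                        std::size_t n, std::size_t h,
                                        std::size_t w, std::size_t c,
                                        const std::vector<float>& mean,
                                        const std::vector<float>& var,
                                        float eps = 1e-5f) {
  std::vector<float> nchw(nhwc.size());
  for (std::size_t in = 0; in < n; ++in)
    for (std::size_t ih = 0; ih < h; ++ih)
      for (std::size_t iw = 0; iw < w; ++iw)
        for (std::size_t ic = 0; ic < c; ++ic) {
          // Source index in NHWC layout
          float x = nhwc[((in * h + ih) * w + iw) * c + ic];
          // Per-channel normalization, as eval-mode batch_norm does it
          float y = (x - mean[ic]) / std::sqrt(var[ic] + eps);
          // Destination index in NCHW layout (the effect of permute(0, 3, 1, 2))
          nchw[((in * c + ic) * h + ih) * w + iw] = y;
        }
  return nchw;
}
```

For example, with ImageNet-style mean {0.485, 0.456, 0.406} and std {0.229, 0.224, 0.225}, a value of 1.0 in channel 0 maps to about (1.0 - 0.485) / 0.229 ≈ 2.249; the eps term only changes the result in the fourth decimal place or beyond.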

I want to do the same in C++ with libtorch, so I wrote the following:

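#include <torch/torch.h>
#include <torch/script.h>
#include <vector>

namespace F = torch::nn::functional;

struct PreProcessImpl : torch::nn::Module {
  // Store mean and variance (std^2) as buffers, mirroring the Python module
  PreProcessImpl(const std::vector<float>& mean, const std::vector<float>& stddev) {
    mean_ = register_buffer("mean", torch::tensor(mean));
    var_ = register_buffer("var", torch::tensor(stddev).pow(2));
  }

  torch::Tensor forward(torch::Tensor x) {
    // NHWC -> NCHW: batch_norm normalizes over dim 1
    x = x.permute({0, 3, 1, 2});

    x = F::batch_norm(x,
                      mean_,
                      var_,
                      F::BatchNormFuncOptions().momentum(0.1).eps(1e-05).training(false));

    return x;
  }

  torch::Tensor mean_, var_;
};

TORCH_MODULE(PreProcess);
PreProcess pp(mean, stddev);  // mean, stddev: std::vector<float>, one entry per channel
auto ppmodule = torch::nn::Sequential(pp, torch::jit::load(modelFile));  // does not compile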

However, the torch::nn::Sequential line does not compile; the compiler reports:

torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h:250:8:
note: candidate function not viable: no known conversion from 'torch::jit::Module'
to 'torch::nn::AnyModule' for 1st argument
  void push_back(AnyModule any_module) {
       ^

Any ideas on how to do this in C++ using libtorch?