Dynamically change num_features in torch.nn.BatchNorm1d

Dear all,

I have been using torch.nn.BatchNorm1d in the following simple GPyTorch model class:

    import torch
    import gpytorch
    import botorch


    class ExactGPModel(gpytorch.models.ExactGP,
                       botorch.models.gpytorch.GPyTorchModel):
        num_outputs = 1

        def __init__(self, train_x, train_y, likelihood):
            super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.RBFKernel())
            # Batch normalization over the two input features
            self.bn = torch.nn.BatchNorm1d(num_features=2)

        def forward(self, x):
            # Normalize the inputs before passing them to the GP
            print(x.size())
            x = self.bn(x)
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

where the size of the input tensor is torch.Size([20, 2]).
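
For reference, my understanding is that torch.nn.BatchNorm1d treats dimension 1 of the input as the feature/channel dimension, so num_features=2 matches a two-dimensional (20, 2) input. A minimal sketch with random data (not my actual training set):

    import torch

    bn = torch.nn.BatchNorm1d(num_features=2)

    # 2-D input of shape (batch, features): dimension 1 has size 2,
    # which matches num_features=2, so this works.
    x = torch.randn(20, 2)
    print(bn(x).shape)  # torch.Size([20, 2])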

Recently I started using BoTorch, and torch.nn.BatchNorm1d no longer works as I intend.

For instance, I want to optimize the acquisition function with BoTorch:

    from botorch.optim import optimize_acqf

    # Optimize the qNIPV acquisition function over the box bounds
    qNIPV, _ = optimize_acqf(
        acq_function=qNIPV,
        bounds=torch.tensor([[kbeg, abeg], [kend, aend]]),
        q=5,  # number of candidates
        num_restarts=15,
        raw_samples=256,
        options={},
    )

and it returns the following error message:

Traceback (most recent call last):
  File "main_GHH_2x2.py", line 131, in <module>
    options={}
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/optim/optimize.py", line 150, in optimize_acqf
    options=options,
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/optim/initializers.py", line 104, in gen_batch_initial_conditions
    X_rnd[start_idx:end_idx].to(device=device)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/utils/transforms.py", line 200, in decorated
    return method(cls, X)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/utils/transforms.py", line 171, in decorated
    return method(cls, X)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/acquisition/active_learning.py", line 94, in forward
    X=X, sampler=self.sampler, observation_noise=True
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/models/model.py", line 124, in fantasize
    post_X = self.posterior(X, observation_noise=observation_noise)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/models/gpytorch.py", line 126, in posterior
    mvn = self(X)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/gpytorch/models/exact_gp.py", line 314, in __call__
    full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/gpytorch/module.py", line 28, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "/home/takafumi/Dropbox/Research/PostDoc/UZH/ML/GPstudy/py/GHHpreferences/TrainGPmodel.py", line 44, in forward
    x = self.bn(x)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 106, in forward
    exponential_average_factor, self.eps)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/torch/nn/functional.py", line 1923, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: running_mean should contain 25 elements not 2

When I print the size of the input tensor, I get:

torch.Size([20, 2])
torch.Size([20, 2])
torch.Size([20, 2])
torch.Size([20, 2])
torch.Size([20, 2])
torch.Size([20, 2])
torch.Size([256, 25, 2])

where I think 256 comes from raw_samples, and 25 is the sum of the original number of training samples (20) and the q=5 that I defined.
This reshaping is presumably done by BoTorch, but I guess torch.nn.BatchNorm1d cannot handle it, because for a 3-D input it treats dimension 1 (here of size 25) as the feature dimension rather than the last dimension (of size 2).
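
If that interpretation is right, the failure can be reproduced without BoTorch by feeding a batched tensor of that shape directly into the layer (again a minimal sketch with random data):

    import torch

    bn = torch.nn.BatchNorm1d(num_features=2)

    # BoTorch evaluates the model on a batched tensor of shape
    # (raw_samples, n_train + q, d) = (256, 25, 2).  For a 3-D input,
    # BatchNorm1d expects dimension 1 (here 25) to equal num_features.
    x_batched = torch.randn(256, 25, 2)
    try:
        bn(x_batched)
    except RuntimeError as e:
        print(e)  # running_mean should contain 25 elements not 2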

Does anybody have an idea how to circumvent this error?
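
One workaround I am considering (I am not sure whether it is the right approach) is to flatten all leading batch dimensions inside forward, so that BatchNorm1d always sees a 2-D tensor with the features in the last dimension, and then restore the original shape:

    def forward(self, x):
        # Collapse any leading batch dimensions to (-1, num_features),
        # apply batch norm, then restore the original shape.
        original_shape = x.shape
        x = self.bn(x.reshape(-1, original_shape[-1])).reshape(original_shape)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

I am not sure, though, whether mixing all 256 candidate sets into a single batch-norm statistic is what I actually want, so any advice would be appreciated.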

Thank you very much in advance.