Dear all,

I have been using `torch.nn.BatchNorm1d` in GPyTorch in the following simple class:
```python
import torch
import gpytorch
from botorch.models.gpytorch import GPyTorchModel


class ExactGPModel(gpytorch.models.ExactGP, GPyTorchModel):
    num_outputs = 1  # tells BoTorch this is a single-output model

    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel())
        # Apply normalization over the 2 input features
        self.bn = torch.nn.BatchNorm1d(num_features=2)

    def forward(self, x):
        # Debug: print the shape of the inputs passed to the model
        print(x.size())
        # Normalize the inputs
        x = self.bn(x)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
```
where the input tensor has size `torch.Size([20, 2])`.
Now I want to use the model with BoTorch, but `torch.nn.BatchNorm1d` does not work as I intend.
For instance, when I optimize the acquisition function `qNIPV` with BoTorch:
```python
from botorch.optim import optimize_acqf

candidates, _ = optimize_acqf(
    acq_function=qNIPV,
    bounds=torch.tensor([[kbeg, abeg], [kend, aend]]),
    q=5,  # number of candidates
    num_restarts=15,
    raw_samples=256,
    options={},
)
```
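For reference, `bounds` is a `2 x d` tensor whose first row holds the lower bounds and whose second row holds the upper bounds. If the only aim of the normalization is to put both inputs on a comparable scale, one alternative to `BatchNorm1d` is fixed min-max scaling with these bounds: it broadcasts over any leading batch dimensions and does not depend on batch statistics. A minimal sketch, assuming hypothetical numeric values for `kbeg`, `abeg`, `kend` and `aend`; `to_unit_cube` and `sample_in_bounds` are illustrative names, not library functions:

```python
import torch

# Hypothetical numbers standing in for kbeg, abeg, kend, aend in the post
kbeg, abeg, kend, aend = 0.5, 0.1, 2.0, 0.9
# Row 0: lower bounds, row 1: upper bounds; one column per input dimension
bounds = torch.tensor([[kbeg, abeg], [kend, aend]])


def to_unit_cube(x: torch.Tensor, bounds: torch.Tensor) -> torch.Tensor:
    """Min-max scale x to [0, 1]^d with fixed bounds. bounds[0] and bounds[1]
    have shape (d,), so this broadcasts over any leading batch dimensions."""
    lower, upper = bounds[0], bounds[1]
    return (x - lower) / (upper - lower)


def sample_in_bounds(*shape: int) -> torch.Tensor:
    """Uniform random points inside the box, with trailing dimension d."""
    return bounds[0] + (bounds[1] - bounds[0]) * torch.rand(*shape, bounds.shape[1])


s_train = to_unit_cube(sample_in_bounds(20), bounds)       # shape [20, 2]
s_batch = to_unit_cube(sample_in_bounds(256, 25), bounds)  # shape [256, 25, 2]
print(s_train.shape, s_batch.shape)
```

Because the scaling uses the same fixed bounds for training points and candidates, both are mapped consistently regardless of how many points are in a batch.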
It returns the following error message:
```
Traceback (most recent call last):
  File "main_GHH_2x2.py", line 131, in <module>
    options={}
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/optim/optimize.py", line 150, in optimize_acqf
    options=options,
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/optim/initializers.py", line 104, in gen_batch_initial_conditions
    X_rnd[start_idx:end_idx].to(device=device)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/utils/transforms.py", line 200, in decorated
    return method(cls, X)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/utils/transforms.py", line 171, in decorated
    return method(cls, X)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/acquisition/active_learning.py", line 94, in forward
    X=X, sampler=self.sampler, observation_noise=True
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/models/model.py", line 124, in fantasize
    post_X = self.posterior(X, observation_noise=observation_noise)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/botorch/models/gpytorch.py", line 126, in posterior
    mvn = self(X)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/gpytorch/models/exact_gp.py", line 314, in __call__
    full_output = super(ExactGP, self).__call__(*full_inputs, **kwargs)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/gpytorch/module.py", line 28, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "/home/takafumi/Dropbox/Research/PostDoc/UZH/ML/GPstudy/py/GHHpreferences/TrainGPmodel.py", line 44, in forward
    x = self.bn(x)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 106, in forward
    exponential_average_factor, self.eps)
  File "/home/takafumi/.pyenv/versions/anaconda3-5.3.1/envs/dice/lib/python3.7/site-packages/torch/nn/functional.py", line 1923, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: running_mean should contain 25 elements not 2
```
When I print the size of the input tensor inside `forward`, I get:
```
torch.Size([20, 2])
torch.Size([20, 2])
torch.Size([20, 2])
torch.Size([20, 2])
torch.Size([20, 2])
torch.Size([20, 2])
torch.Size([256, 25, 2])
```
I think 256 comes from `raw_samples`, and 25 is the sum of the original number of training samples (20) and the `q` (5) that I defined.
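The `[256, 25, 2]` shape and the error can be reproduced with plain PyTorch. The sketch below roughly mirrors the `full_inputs` step of GPyTorch's `ExactGP.__call__` (visible in the traceback) in eval mode: the training inputs are expanded to the batch shape of the candidate tensor (`raw_samples x q x d`) and concatenated in front of it along the point dimension. The sizes come from this post; the random data is a placeholder. It then shows that `BatchNorm1d` reads dimension 1 of a 3-D input `(N, C, L)` as the feature dimension, so it expects 25 features instead of 2:

```python
import torch

# Sizes from the post: 20 training points, d = 2 inputs, q = 5, raw_samples = 256
n_train, d, q, raw_samples = 20, 2, 5, 256

train_x = torch.rand(n_train, d)
# Candidate batch evaluated during initialization: raw_samples x q x d
X = torch.rand(raw_samples, q, d)

# Roughly what ExactGP does in eval mode: expand the training inputs to the
# batch shape of X and prepend them along the point dimension (dim=-2)
train_expanded = train_x.expand(raw_samples, n_train, d)
full_inputs = torch.cat([train_expanded, X], dim=-2)
print(full_inputs.shape)  # torch.Size([256, 25, 2])

# For a 3-D input, BatchNorm1d treats dim 1 (here 25) as the feature dimension
bn = torch.nn.BatchNorm1d(num_features=d)
bn.eval()  # model.posterior() puts the whole model, including bn, in eval mode
raised = False
try:
    bn(full_inputs)
except RuntimeError as err:
    raised = True
    print(err)  # complains that running_mean has 2 elements, but 25 are expected
```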
I guess this reshaping is done by BoTorch, but `torch.nn.BatchNorm1d` does not handle the resulting 3-D input the way I intended.
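One way to make the layer indifferent to extra batch dimensions is to flatten all leading dimensions before calling `BatchNorm1d` and restore the shape afterwards, so the layer always sees `(N, 2)`. This is a minimal sketch; `FlatBatchNorm1d` is a hypothetical helper name, not part of PyTorch, GPyTorch or BoTorch:

```python
import torch


class FlatBatchNorm1d(torch.nn.Module):
    """Hypothetical helper: BatchNorm1d over the last (feature) dimension of an
    input with any number of leading batch dimensions, e.g. [n, d] or [b, n, d]."""

    def __init__(self, num_features: int):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Collapse all leading dimensions so BatchNorm1d sees (N, C) with C = d
        flat = x.reshape(-1, x.shape[-1])
        out = self.bn(flat)
        # Restore the original (batch) shape
        return out.reshape(x.shape)


norm = FlatBatchNorm1d(num_features=2)

# Training mode: a [20, 2] input is normalized with its own batch statistics
train_out = norm(torch.rand(20, 2))

# Eval mode: a [256, 25, 2] input now works and uses the running statistics
norm.eval()
eval_out = norm(torch.rand(256, 25, 2))
print(train_out.shape, eval_out.shape)
```

In the model class, `self.bn = FlatBatchNorm1d(num_features=2)` would take the place of the plain `BatchNorm1d` layer. This only removes the shape mismatch: in eval mode the layer normalizes with running statistics accumulated during training, which may not exactly match the batch statistics the GP saw while fitting.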
Does anybody have an idea how to circumvent this error?
Thank you very much in advance.