LibTorch error: "Expected more than 1 value per channel when training, got input size [1, 256]"

// Fully connected network (512 -> 256 -> 128 -> 4096) with a BatchNorm1d layer
// after each of the first two linear layers.
// NOTE(review): the submodules are created as default member initializers and are
// never passed to register_module() in this snippet — confirm whether that is intended.
class MODEL : public torch::nn::Module {
public:
    // Runs x through fc1 -> bn1 -> ReLU -> fc2 -> bn2 -> ReLU -> fc3 and returns the result.
    torch::Tensor forward(torch::Tensor x);

private:
    // Linear layers; the two constructor arguments are the input and output feature counts.
    torch::nn::Linear fc1 = torch::nn::Linear(512, 256);
    torch::nn::Linear fc2 = torch::nn::Linear(256, 128);
    torch::nn::Linear fc3 = torch::nn::Linear(128, 4096);
    // NOTE(review): not referenced by forward() as shown; forward() calls
    // torch::nn::functional::relu instead.
    torch::nn::ReLU relu = torch::nn::ReLU();

    // Batch-norm layers sized to the outputs of fc1 (256) and fc2 (128), with
    // running-statistics tracking switched off.
    // NOTE(review): with track_running_stats(false) the layers presumably have to
    // normalize with statistics of the current batch — confirm how this interacts
    // with eval() and a batch of size 1.
    torch::nn::BatchNorm1d bn1 = torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(256).track_running_stats(false));
    torch::nn::BatchNorm1d bn2 = torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(128).track_running_stats(false));

};

// Forward pass: (Linear -> BatchNorm1d -> ReLU) twice, then a final Linear layer.
// x must carry 512 features per sample to match fc1; the returned tensor is the
// output of fc3 (4096 features per sample).
torch::Tensor MODEL ::forward(torch::Tensor x)
{
    // Stage 1: 512 -> 256 features, normalized by bn1, then ReLU.
    x = fc1->forward(x);
    x = bn1->forward(x);
    x = torch::nn::functional::relu(x);
    // Stage 2: 256 -> 128 features.
    x = fc2->forward(x);
    // Debug print of bn2's training flag; the author observed it printing false.
    cout << bn2->is_training() ;  // false
    // NOTE(review): the reported size [1, 256] matches the output of fc1 (the input
    // of bn1), not the 128-feature input of bn2 — the failure may originate in bn1.
    x = bn2->forward(x);     // -> Expected more than 1 value per channel when training, got input size [1, 256]
    x = torch::nn::functional::relu(x);
    // Output layer: 128 -> 4096 features.
    x = fc3->forward(x);

    return x;
}

// Build the model, put it in evaluation mode and run one input through it.
MODEL model = MODEL();
model.eval();
// X is defined outside this snippet; per the trailing note it has shape [1, 512],
// i.e. a batch holding a single 512-feature sample.
auto ouput = model.forward(X);  // X [1, 512]

I encountered the error "Expected more than 1 value per channel when training".
Is there a solution, or did I make a mistake somewhere?

Increase the batch size as batchnorm layers need to calculate the stats to normalize the input activation from multiple elements for each channel.
Alternatively, remove the batchnorm layers if you cannot increase the number of samples.

Thank you for your answer. I have an additional question.

  1. Why doesn’t this error occur here? ( x = bn1->forward(x); )

  2. I converted the code to C++ and loaded the weights so that I can use a model pre-trained in Python from C++…
    Will the results differ from Python’s if the batchNorm layers are removed from the C++ version?

  1. It does already fail in bn1:
class MyModel(nn.Module):  # Python counterpart of the C++ MODEL, used to reproduce the error
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 256)  # 512 -> 256 features
        self.fc2 = nn.Linear(256, 128)  # 256 -> 128 features
        self.fc3 = nn.Linear(128, 4096)  # 128 -> 4096 features
        
        self.bn1 = nn.BatchNorm1d(256, track_running_stats=False)  # applied to fc1's output; running stats disabled
        self.bn2 = nn.BatchNorm1d(128, track_running_stats=False)  # applied to fc2's output; running stats disabled
        
    def forward(self, x):  # x: presumably (batch, 512), matching fc1's input size
        x = self.fc1(x)
        x = self.bn1(x)  # first batchnorm call; the traceback in this thread points here
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x  # output of fc3: 4096 features per sample
    
model = MyModel()  # eval() is never called; the error message indicates training behaviour
x = torch.randn(1, 512)  # a single sample: batch size 1, 512 features
out = model(x)  # raises the ValueError shown in the traceback below
# Traceback (most recent call last):
#   ...
#   File "/tmp/ipykernel_162908/2678004044.py", line 13, in forward
#     x = self.bn1(x)
# ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256])
  2. Yes, removing the batchnorm layers means the activations are no longer normalized, so the output will differ from the Python results.

You can make it work (in eval mode) by registering your modules and removing `track_running_stats(false)` — with tracking disabled, batch statistics are used even in eval mode, and a variance can’t be computed from a batch of size 1 — e.g.:

/// Fully connected network 512 -> 256 -> 128 -> 4096 with batch normalization
/// after each of the first two linear layers.
///
/// Every submodule is registered under its member name, and the batch-norm
/// layers are built from options that only specify the feature count.
class MODEL : public torch::nn::Module
{
public:
	/// Creates the submodules and registers them in the order
	/// fc1, fc2, fc3, bn1, bn2 (members are initialized in declaration order).
	MODEL()
		: fc1(register_module("fc1", torch::nn::Linear(512, 256))),
		  fc2(register_module("fc2", torch::nn::Linear(256, 128))),
		  fc3(register_module("fc3", torch::nn::Linear(128, 4096))),
		  bn1(register_module("bn1", torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(256)))),
		  bn2(register_module("bn2", torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(128))))
	{
	}

	/// Applies (Linear -> BatchNorm1d -> ReLU) twice and then the output layer.
	///
	/// @param x input tensor with 512 features per sample
	/// @return the output of fc3 (4096 features per sample)
	torch::Tensor forward(torch::Tensor x)
	{
		namespace F = torch::nn::functional;

		// Stage 1: 512 -> 256 features.
		const torch::Tensor stage1 = F::relu(bn1->forward(fc1->forward(x)));
		// Stage 2: 256 -> 128 features.
		const torch::Tensor stage2 = F::relu(bn2->forward(fc2->forward(stage1)));
		// Output layer: 128 -> 4096 features.
		return fc3->forward(stage2);
	}

private:
	// Keep this declaration order: it fixes the registration order used by the constructor.
	torch::nn::Linear fc1;
	torch::nn::Linear fc2;
	torch::nn::Linear fc3;
	torch::nn::BatchNorm1d bn1;
	torch::nn::BatchNorm1d bn2;
};

// Example usage: create the model, put it in evaluation mode and run it on a
// random batch that contains a single 512-feature sample.
auto model = std::make_shared<MODEL>();
model->eval();
auto input = torch::randn({ 1, 512 });
auto output = model->forward(input);

Calling eval() on the newly initialized batchnorm layers will not normalize the inputs at all since the initial running stats are used, where running_mean is initialized with zeros and running_var with ones.
In this case, only the affine weight and bias are used/trained and the layers can thus be replaced with plain linear layers.

Yes, it doesn’t make sense (either training batchnorm with batch size 1 or running inference on newly initialized model) but I thought the plan was to load something already trained …

I converted the code to C++ and loaded the weights so that I can use a model pre-trained in Python from C++…
Will the results differ from Python’s if the batchNorm layers are removed from the C++ version?

First, I solved the error!
Thank you all.