[solved] nn.BatchNorm1d() throws an error even though model.eval() is set when nn.Sequential() is used

I have the following simple sequential network which does classification:

import torch
import torch.nn as nn

input_dims = 36
output_dims = 16
hidden_dims = 4
num_hidden_layers = 4

class FC(nn.Module):
    def __init__(self):
        super(FC, self).__init__()
        lay = nn.Sequential(
                nn.Linear(input_dims, hidden_dims*input_dims),
                nn.BatchNorm1d(hidden_dims*input_dims),
                nn.ReLU())
        self.fc_layers = [lay]
        for i in range(0, num_hidden_layers):
            lay = nn.Sequential(
                    nn.Linear(hidden_dims*input_dims, hidden_dims*input_dims),
                    nn.BatchNorm1d(hidden_dims*input_dims),
                    nn.ReLU())
            self.fc_layers.append(lay)
        self.out = nn.Linear(hidden_dims*input_dims, output_dims)
        
        self._initialize_weights()
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        for lay in self.fc_layers:
            x = lay(x)
        return self.out(x)
        
    def _initialize_weights(self):
        relu_gain = nn.init.calculate_gain('relu')
        
        for i in range(0, len(self.fc_layers)):
            nn.init.xavier_normal(self.fc_layers[i][0].weight, relu_gain)
            
            nn.init.normal(self.fc_layers[i][1].weight, mean=1.0, std=0.02)
            self.fc_layers[i][1].bias.data.fill_(0)
        
        nn.init.xavier_normal(self.out.weight)
        
        print('Initialized weights')

model = FC()

During training I set model.train(), and the following is the code I use for testing:

test_count = 0
correct = 0

for test_count in tqdm.tqdm(range(0, test_size)):
    model.eval()
    
    syn, target = Data.generate(batch_size = 1)
    syn = torch.Tensor(syn).view(1, -1)
    syn = Variable(syn)

    pred = model(syn)
    _, pred_class = torch.max(pred.data, 1)
    pred_class = pred_class.numpy()
    
    correct += (pred_class == target).sum()
    test_count += 1
    
    accuracy = 100*(correct/(test_count))

I get the following error:

---> 20     pred = model(syn)
     21     _, pred_class = torch.max(pred.data, 1)
     22     pred_class = pred_class.numpy()

ValueError: Expected more than 1 value per channel when training, got input size [1, 144]

Please note that I've set model.eval() during testing. I also want to mention that initially I didn't use nn.Sequential(), and a model with the same architecture didn't give me any error. Please let me know if I'm missing anything.

The problem is that you put all your layers in a plain Python list (self.fc_layers), which is not a module. Calling eval() recursively switches every registered submodule of your model to eval mode, but since self.fc_layers is not a module, its contents are never switched, so the BatchNorm1d layers stay in training mode.
If you do not want to modify your model, you will have to call eval() on every module in self.fc_layers:

model.eval()
for m in model.fc_layers:
    m.eval()
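
That also explains why the traceback complains about training mode: BatchNorm1d cannot compute batch statistics from a single sample, but in eval mode it falls back to its running statistics. A minimal standalone sketch of just the layer (not your model, and assuming a recent PyTorch where Variable is no longer needed):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(144)
x = torch.randn(1, 144)   # a batch containing a single sample

bn.train()
# bn(x)  # raises: ValueError: Expected more than 1 value per channel when training
#        # (batch statistics cannot be estimated from one sample)

bn.eval()
y = bn(x)  # works: the stored running statistics are used instead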

The other solution is to create a Sequential from self.fc_layers by calling

self.fc_layers = nn.Sequential(*self.fc_layers)

The rest does not need any modification, although it is then simpler to just call x = self.fc_layers(x) instead of the for loop.
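
For concreteness, a sketch of how __init__ and forward could look with that change (using the same input_dims, hidden_dims, num_hidden_layers and output_dims from the question; _initialize_weights can stay as posted, since nn.Sequential supports integer indexing just like the list did):

class FC(nn.Module):
    def __init__(self):
        super(FC, self).__init__()
        layers = [nn.Sequential(
                nn.Linear(input_dims, hidden_dims*input_dims),
                nn.BatchNorm1d(hidden_dims*input_dims),
                nn.ReLU())]
        for i in range(0, num_hidden_layers):
            layers.append(nn.Sequential(
                    nn.Linear(hidden_dims*input_dims, hidden_dims*input_dims),
                    nn.BatchNorm1d(hidden_dims*input_dims),
                    nn.ReLU()))
        # wrapping the blocks in a Sequential registers them as submodules,
        # so eval(), train(), parameters(), cuda() and state_dict() all see them
        self.fc_layers = nn.Sequential(*layers)
        self.out = nn.Linear(hidden_dims*input_dims, output_dims)
        self._initialize_weights()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)   # replaces the explicit for loop
        return self.out(x)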


The first solution will not work well, because if model.fc_layers is an ordinary Python list, the parameters of its submodules are not listed in model.parameters(), so an optimizer built from model.parameters() never updates them.
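
You can see this directly on the model from the question; with the plain list, only self.out is visible to PyTorch:

model = FC()
# only self.out's weight and bias are registered; everything inside the
# plain Python list is invisible to model.parameters()
print(len(list(model.parameters())))   # 2, instead of 22 for the full network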

Another solution is to use nn.ModuleList().

self.fc_layers = nn.ModuleList([lay])

self.fc_layers then behaves like a Python list, but its contents are properly registered with the model.
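
Applied to the code from the question, __init__ would look roughly like this (a sketch; forward() and _initialize_weights() can stay exactly as posted, since nn.ModuleList supports iteration and integer indexing):

class FC(nn.Module):
    def __init__(self):
        super(FC, self).__init__()
        lay = nn.Sequential(
                nn.Linear(input_dims, hidden_dims*input_dims),
                nn.BatchNorm1d(hidden_dims*input_dims),
                nn.ReLU())
        self.fc_layers = nn.ModuleList([lay])   # the only change vs. the original
        for i in range(0, num_hidden_layers):
            lay = nn.Sequential(
                    nn.Linear(hidden_dims*input_dims, hidden_dims*input_dims),
                    nn.BatchNorm1d(hidden_dims*input_dims),
                    nn.ReLU())
            self.fc_layers.append(lay)          # append() registers each block
        self.out = nn.Linear(hidden_dims*input_dims, output_dims)
        self._initialize_weights()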


Thank you for the information. Now I understand my mistake in the network and why the performance drop happened.