AlexNet dimensions issue

I’m doing some tinkering with a modified AlexNet and adding in some BatchNorm to look at the position of batchnorm in relation to the activation function, and I’m getting a dimensions error, and I can’t seem to figure out where it’s coming from. I more or less copied the AlexNet architecture from the PyTorch code, but added in BatchNorm. Here’s my model class, train function, and training loop:

 class AlexNet_bn_first(nn.Module):
    def __init__(self):
        super(AlexNet_bn_first, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.BatchNorm2d(4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.BatchNorm2d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 6),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        
        return F.log_softmax(x, dim=1)

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

for epoch in range(num_epochs):
    train(model, device, train_loader, optimizer, epoch)
    test_loss = test(model, device, test_loader)
    is_best = test_loss < best_loss
    test_losses.append(test_loss)
    save_checkpoint({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'best_loss': best_loss,
        'optimizer': optimizer.state_dict(),
    }, is_best)

And here’s the error traceback:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-16-495c1ea3f5ba> in <module>
      5 
      6 for epoch in range(1, num_epochs+1):
----> 7     train(model, device, train_loader, optimizer, epoch)
      8     test_loss = test(model, device, test_loader)
      9     is_best = test_loss < best_loss

<ipython-input-3-5c27374f04c3> in train(model, device, train_loader, optimizer, epoch, log_interval)
      4         data, target = data.to(device), target.to(device)
      5         optimizer.zero_grad()
----> 6         output = model(data)
      7         loss = F.nll_loss(output, target)
      8         loss.backward()

~\AppData\Local\Continuum\anaconda3\envs\detection\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

<ipython-input-12-199dfe1fc3cc> in forward(self, x)
     39         x = self.avgpool(x)
     40         x = torch.flatten(x, 1)
---> 41         x = self.classifier(x)
     42 
     43         return F.log_softmax(x, dim=1)

~\AppData\Local\Continuum\anaconda3\envs\detection\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~\AppData\Local\Continuum\anaconda3\envs\detection\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
     90     def forward(self, input):
     91         for module in self._modules.values():
---> 92             input = module(input)
     93         return input
     94 

~\AppData\Local\Continuum\anaconda3\envs\detection\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~\AppData\Local\Continuum\anaconda3\envs\detection\lib\site-packages\torch\nn\modules\batchnorm.py in forward(self, input)
     59     @weak_script_method
     60     def forward(self, input):
---> 61         self._check_input_dim(input)
     62 
     63         # exponential_average_factor is self.momentum set to

~\AppData\Local\Continuum\anaconda3\envs\detection\lib\site-packages\torch\nn\modules\batchnorm.py in _check_input_dim(self, input)
    248         if input.dim() != 4:
    249             raise ValueError('expected 4D input (got {}D input)'
--> 250                              .format(input.dim()))
    251 
    252 

ValueError: expected 4D input (got 2D input)

It looks like the error is happening with the input to self.classifier. Indeed, I do flatten the tensor to 2D before it goes into the classifier, and when I check model.classifier, this is what I get:

Sequential(
  (0): Dropout(p=0.5)
  (1): Linear(in_features=9216, out_features=4096, bias=True)
  (2): BatchNorm2d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU(inplace)
  (4): Dropout(p=0.5)
  (5): Linear(in_features=4096, out_features=4096, bias=True)
  (6): BatchNorm2d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): ReLU(inplace)
  (8): Linear(in_features=4096, out_features=6, bias=True)
)

So, it should be expecting a 2D input, no? What am I missing here?

The classifier you defined uses 2d BatchNorm, which means it acts on the two spatial dimensions of an image, hence expecting a 4d tensor(batch x channels x height x width).
For tensors of the forms 3d tensors (batch x time-step x features) or 2d tensors (batch x features), you should use 1d BatchNorm
Modified code shown below

self.classifier = nn.Sequential( 
    nn.Dropout(), 
    nn.Linear(256 * 6 * 6, 4096), 
    nn.BatchNorm1d(4096), 
    nn.ReLU(inplace=True), 
    nn.Dropout(), 
    nn.Linear(4096, 4096), 
    nn.BatchNorm1d(4096), 
    nn.ReLU(inplace=True), 
    nn.Linear(4096, 6), 
    )
1 Like

d’oh! Thanks! I knew it would be something obvious like this. I kept looking at the input to the classifier, and not inside the classifier. Appreciate the help!

1 Like