SyncBatchNorm.convert_sync_batchnorm() causes ValueError: expected at least 3D input (got 2D input)

Not sure if something is missing, but isn’t SyncBatchNorm.convert_sync_batchnorm() supposed to convert the module transparently?
However, the following code segment produces ValueError: expected at least 3D input (got 2D input).
Without the conversion, the forward goes as expected.
Any ideas?

import os
import torch
from torch import nn

module = torch.nn.Sequential(
           torch.nn.Linear(20, 100),
           torch.nn.BatchNorm1d(100)
         ).cuda()

# Create a single-process group (rank 0 in a world of size 1) via environment variables.
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '25791'

# init_process_group() returns None; process_group=None selects the default group.
torch.distributed.init_process_group(backend='nccl')
module = nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group=None)

input = torch.randn(2, 20).cuda()
output = module(input)
print(output.shape)

The output:

Traceback (most recent call last):
  File "syncBN.py", line 21, in <module>
    output = module(input)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 429, in forward
    self._check_input_dim(input)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 417, in _check_input_dim
    .format(input.dim()))
ValueError: expected at least 3D input (got 2D input)

Expected output, as without the conversion:

torch.Size([2, 100])

Ubuntu 16.04 with PyTorch 1.3 installed through conda.

Hi,

The original modules such as BatchNorm1d or BatchNorm2d support inputs without a batch dimension, so they handle 2D/3D and 3D/4D inputs respectively.
SyncBatchNorm has no specialized 1d/2d/3d variants; a single class handles all of them. To know which case it is in, it would have to rely on the number of input dimensions, and, as you see above, a 3D input could be either a batched 1D input or an unbatched 2D input. So it only allows inputs that have a batch dimension.

@albanD
I looked into the code and found this restriction is imposed by SyncBatchNorm:

    def _check_input_dim(self, input):
        if input.dim() <= 2:
            raise ValueError('expected at least 3D input (got {}D input)'
                             .format(input.dim()))

This is quite different from the check in the original BatchNorm1d being wrapped:

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))

What confused me is that the code segment in my first post comes straight from your SyncBatchNorm API documentation, which uses BatchNorm1d to demonstrate convert_sync_batchnorm.
Why doesn’t SyncBatchNorm explicitly check whether the module to wrap is BatchNorm1d or BatchNorm2d instead of the general _BatchNorm in convert_sync_batchnorm?
If this is not going to work, what is the right way to use convert_sync_batchnorm for those models with BatchNorm1d?
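Restating the two quoted checks as plain predicates (a torch-free sketch; the function names are made up for illustration and only mirror the PyTorch 1.3 code quoted above) shows exactly which input ranks the conversion gains and loses:

```python
def batchnorm1d_accepts(ndim: int) -> bool:
    # Same condition as the BatchNorm1d._check_input_dim quoted above:
    # anything other than 2D or 3D raises ValueError.
    return ndim in (2, 3)


def syncbatchnorm_accepts(ndim: int) -> bool:
    # Same condition as the SyncBatchNorm._check_input_dim quoted above:
    # inputs with 2 or fewer dimensions raise ValueError.
    return ndim > 2


for ndim in range(1, 6):
    print(f"{ndim}D  BatchNorm1d={batchnorm1d_accepts(ndim)}  "
          f"SyncBatchNorm={syncbatchnorm_accepts(ndim)}")
```

Only 2D flips from accepted to rejected, which is precisely the (N, L) input in the example above; 4D and 5D become accepted because the same class also stands in for BatchNorm2d and BatchNorm3d.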

If this is not going to work, what is the right way to use convert_sync_batchnorm for those models with BatchNorm1d?

I think the fix here is to ensure you always have a batch dimension. Potentially adding an .unsqueeze(0) to your input.

Then why doesn’t SyncBatchNorm explicitly check whether the module to wrap is BatchNorm1d or BatchNorm2d instead of the general _BatchNorm in convert_sync_batchnorm?

This would be a nice addition; we would be happy to merge a PR that adds this feature!

I think the fix here is to ensure you always have a batch dimension. Potentially adding an .unsqueeze(0) to your input.

The example input (2, 20) already contains a batch dim, indicating a batch of two 1D examples.
If we fake the input with unsqueeze(0), how could it work when there are other modules before BatchNorm1d in the model that may assume the 0th dim must be the batch dim?
After all, the layer(s) before BatchNorm1d can be anything else in general, right?

BTW, I tried to make it of size (1, 2, 20), but it still complains about the running_mean size:

Traceback (most recent call last):
  File "syncBN.py", line 29, in <module>
    output = module(input)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 459, in forward
    exponential_average_factor, self.eps)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/functional.py", line 1670, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: running_mean should contain 2 elements not 100

This would be a nice addition; we would be happy to merge a PR that adds this feature!

So you are suggesting it is not intentional but something that could be completed?
If that is the case, it seems like creating an issue on the GitHub repo makes more sense and I will look into the details under the hood.

So, to sum up, models that already use BatchNorm1d are likely out of luck: they cannot be converted to SyncBN transparently by default.


The example input (2, 20) already contains a batch dim, indicating a batch of two 1D examples.

That is not how BatchNorm1d works. BatchNorm1d assumes an optional first batch dimension, then a channel dimension, then the actual data dimension. So the input is 2D without a batch dimension and 3D with one.

BTW, I tried to make it of size (1, 2, 20), but it still complains about the running_mean size:

This is because you define your batchnorm as having 100 channels, but what you give as input has 2.
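To make the mismatch concrete, here is a torch-free trace of the shapes. It is a sketch with hypothetical helper names, assuming two PyTorch conventions: nn.Linear only changes the last dimension, and batch norm keeps one running statistic per entry of dim 1.

```python
def linear_out_shape(in_shape, out_features):
    # nn.Linear maps (*, in_features) -> (*, out_features); leading dims pass through.
    return in_shape[:-1] + (out_features,)


def bn_channels(shape):
    # Batch norm keeps one running_mean / running_var entry per slice of dim 1.
    return shape[1]


num_features = 100                      # BatchNorm1d(100)
x = (1, 2, 20)                          # torch.randn(2, 20).unsqueeze(0)
h = linear_out_shape(x, num_features)   # shape that reaches BatchNorm1d
print(h, bn_channels(h))                # prints: (1, 2, 100) 2
```

The layer therefore sees 2 channels while its running_mean holds 100 entries. Adding the extra axis after the Linear layer at the end instead, i.e. (2, 100, 1), keeps the 100 features at dim 1.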

I am experiencing exactly the same problem as you right now, @farleylai.

I am trying to run a model with a ResNet backbone, which contains only BatchNorm2d layers, and a head network with exactly ONE BatchNorm1d, which is exactly what causes the problem.

The input to the BatchNorm1d in the forward function of the model is [64,2048].
As suggested by @albanD, I unsqueezed it in the forward function so that the input shape is now [64, 1, 2048]. The next module is a Linear classifier, so I squeezed the output of the BatchNorm1d back to [64, 2048] as input to the Linear layer. This helped in the sense that the forward pass now works, but in the backward pass I now get an error:

RuntimeError: Function SyncBatchNormBackward returned an invalid gradient at index 1 - got [1] but expected shape compatible with [2048]

Any suggestions @albanD ?

Do you have a small code sample so that I can reproduce that locally?

Not really sure what you mean by ‘small code sample’. So I will try:

class Net(nn.Module):
    in_planes = 2048

    def __init__(self, num_classes, model_path,  model_name):
        super(Net, self).__init__()
        self.base = ResNet(block=Bottleneck,
                           layers=[3, 4, 6, 3])

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.num_classes = num_classes
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)  # no shift
        self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)

    def forward(self, x):

        global_feat = self.gap(self.base(x))  # (b, 2048, 1, 1)
        global_feat = global_feat.view(global_feat.shape[0], -1)  # flatten to (bs, 2048)

        feat = self.bottleneck(global_feat.unsqueeze(1))  ### To allow SyncBatchnorm

        cls_score = self.classifier(feat.squeeze())   ### To adjust for Linear layer input

        return cls_score, global_feat 


def train():
    # model, optimizer, loss_fn, device and batch come from the surrounding training script
    model.train()
    optimizer.zero_grad()
    img, target = batch
    img = img.to(device)
    target = target.to(device)
    score, feat = model(img)
    LOSS = loss_fn(score, feat, target)
    LOSS.backward()
    optimizer.step()

I shortened the code as much as possible while keeping what I think are the most important parts. The image size (the input in the train function) is [batch_size, 3, W, H], i.e. standard for images.

EDIT: fixed the formatting and made the code a little clearer.

I think I got it right now.

According to the docs (the torch.nn page of the PyTorch 2.1 documentation), quoting:

Parameters

  • num_features – C from an expected input of size (N,C,L) or L from input of size (N,L)

So I figured that since we need a 3D input, and BatchNorm1d takes num_features as C for a 3D input, the singleton dimension should be the last one.
So instead of

feat = self.bottleneck(global_feat.unsqueeze(1)) # Which gives [bs, 1, 2048]
I just did:
feat = self.bottleneck(global_feat.unsqueeze(-1)) # Which gives [bs, 2048, 1]

No more errors, and training seems to run smoothly with SyncBatchNorm as well. Hope this helps someone.
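The difference between the two calls is only which axis lands at dim 1, the axis batch norm treats as channels. The sketch below reproduces the indexing rule of torch.Tensor.unsqueeze on plain shape tuples (the helper name is hypothetical) for the [64, 2048] features from this post. With unsqueeze(1) the layer sees a single channel, which is consistent with the "got [1] but expected shape compatible with [2048]" gradient error above.

```python
def unsqueeze_shape(shape, dim):
    # Same index rule as torch.Tensor.unsqueeze: a negative dim counts from the
    # end of the *output*, so -1 appends a new trailing axis.
    if dim < 0:
        dim += len(shape) + 1
    return shape[:dim] + (1,) + shape[dim:]


num_features = 2048
feat = (64, 2048)  # flattened global features, (bs, 2048)

for dim in (1, -1):
    out = unsqueeze_shape(feat, dim)
    channels = out[1]  # batch norm normalizes per entry of dim 1
    print(f"unsqueeze({dim}) -> {out}, channels={channels}, "
          f"matches num_features: {channels == num_features}")
```

On the way out, squeeze(-1) rather than a bare squeeze() restores (bs, 2048) and also keeps the batch dimension if bs happens to be 1.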


That is not how BatchNorm1d works. BatchNorm1d assumes an optional first batch dimension, then a channel dimension, then the actual data dimension. So the input is 2D without a batch dimension and 3D with one.

As defined by BatchNorm1d, the input is expected to be of size (N, L) or (N, C, L), with the batch dimension first. According to the documentation, what is optional for BatchNorm1d is the additional channel dimension, not the batch dimension.

This is because you define your batchnorm as having 100 channels, but what you give as input has 2.

The (1, 2, 20) shape comes from following the suggestion to add .unsqueeze(0) to the input, but that shape is not what was originally intended. By definition, whether 100 is read as C or L in the previous example, BatchNorm1d produces the same results given (N, 100) or (N, 100, 1). (2, 100) is already a batched input of two 1D feature vectors and matches the input accepted by BatchNorm1d. We need to be on the same page about this.
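The claim that (N, 100) and (N, 100, 1) give the same result follows from how the statistics are pooled: for a 3D input, each channel's mean and variance are taken over N × L values, and with L = 1 those are exactly the N values of the 2D case. A minimal pure-Python sketch checks this on a tiny batch. It assumes training-mode batch statistics with biased variance and no affine weight/bias; running statistics are not modeled, and the function names are made up for illustration.

```python
import math


def bn_nc(x, eps=1e-5):
    # x: N rows of C features, i.e. shape (N, C).
    # Each feature column is normalized with its batch mean and biased variance.
    n, c = len(x), len(x[0])
    out = [[0.0] * c for _ in range(n)]
    for ch in range(c):
        col = [x[i][ch] for i in range(n)]
        mean = sum(col) / n
        var = sum((v - mean) ** 2 for v in col) / n
        for i in range(n):
            out[i][ch] = (x[i][ch] - mean) / math.sqrt(var + eps)
    return out


def bn_ncl(x, eps=1e-5):
    # x: shape (N, C, L). Statistics for channel ch are pooled over all N * L values.
    n, c, l = len(x), len(x[0]), len(x[0][0])
    out = [[[0.0] * l for _ in range(c)] for _ in range(n)]
    for ch in range(c):
        vals = [x[i][ch][j] for i in range(n) for j in range(l)]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        for i in range(n):
            for j in range(l):
                out[i][ch][j] = (x[i][ch][j] - mean) / math.sqrt(var + eps)
    return out


x_nc = [[0.5, -1.0, 2.0], [1.5, 3.0, -2.0]]   # (N=2, C=3)
x_nc1 = [[[v] for v in row] for row in x_nc]  # same data as (N=2, C=3, L=1)

a = bn_nc(x_nc)
b = [[cell[0] for cell in row] for row in bn_ncl(x_nc1)]  # drop L=1 again
print(all(math.isclose(p, q) for ra, rb in zip(a, b) for p, q in zip(ra, rb)))  # prints True
```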

Now, get back to the issue with SyncBatchNorm conversion. Two questions:

  1. Does a BatchNorm1d wrapped by SyncBatchNorm behave as it did before the conversion?

The original BatchNorm1d accepts both (N, L) and (N, C, L) and produces the same results, as the following revised code segment shows. However, conversion to SyncBatchNorm CHANGES the interface so that it ONLY accepts input of size (N, C, L). The conversion is therefore unlikely to work transparently with existing models that feed BatchNorm1d inputs of size (N, L).

import os
import copy
import torch
from torch import nn

with torch.no_grad():
    inputNL = torch.randn(2, 20).cuda()
    module = torch.nn.Sequential(
               torch.nn.Linear(20, 100),
               torch.nn.BatchNorm1d(100)
             ).cuda()
    moduleC = copy.deepcopy(module).cuda()
    moduleL = copy.deepcopy(module).cuda()
    moduleC.eval()
    moduleL.eval()

    # XXX: BatchNorm1d accepts (N, C, L) 
    outputNL = moduleC[0](inputNL)
    outputNCL = moduleC[1](outputNL.unsqueeze(-1))
    print('BatchNorm1d NCL:', outputNCL.shape, round(outputNCL.mean().item(), 7))

    # XXX: BatchNorm1d accepts (N, L) too
    outputNL = moduleL[0](inputNL)
    outputNL = moduleL[1](outputNL)
    print('BatchNorm1d NL:', outputNL.shape, round(outputNL.mean().item(), 7))

    os.environ['RANK'] = '0'
    os.environ['WORLD_SIZE'] = '1'
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '25791'
    torch.distributed.init_process_group(backend='nccl')

    moduleC = copy.deepcopy(module)
    moduleL = copy.deepcopy(module)
    moduleC = nn.SyncBatchNorm.convert_sync_batchnorm(moduleC)
    moduleL = nn.SyncBatchNorm.convert_sync_batchnorm(moduleL)
    moduleC.eval()
    moduleL.eval()

    # XXX: converted BatchNorm1d ONLY accepts (N, C, L) 
    outputNL = moduleC[0](inputNL)
    outputNCL = moduleC[1](outputNL.unsqueeze(-1))
    print('SyncBatchNorm NCL:', outputNCL.shape, round(outputNCL.mean().item(), 7))

    # FIXME: Converted BatchNorm1d never accepts (N, L)
    outputNL = moduleL[0](inputNL)
    outputNL = moduleL[1](outputNL)
    print('SyncBatchNorm NL:', outputNL.shape, round(outputNL.mean().item(), 7))

Sample output:

BatchNorm1d NCL: torch.Size([2, 100, 1]) 0.0683341
BatchNorm1d NL: torch.Size([2, 100]) 0.0683341
SyncBatchNorm NCL: torch.Size([2, 100, 1]) 0.0683341
Traceback (most recent call last):
  File "syncBN.py", line 45, in <module>
    outputNL = moduleL[1](outputNL)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 429, in forward
    self._check_input_dim(input)
  File "/home/ml/farleylai/miniconda3/envs/sinet37/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 417, in _check_input_dim
    .format(input.dim()))
ValueError: expected at least 3D input (got 2D input)
  2. If not, what is the justification, or what workaround exists that does not require changing the existing model being wrapped?

One workaround is to reshape the immediate (N, L) input with unsqueeze(-1) to (N, C=L, L=1) before the converted BatchNorm1d, as demonstrated by @bonzogondo. Unfortunately, this may not scale if BatchNorm1d is used all over existing models, and there are no reshape layers in PyTorch to automate the unsqueeze. An alternative could be to record whether the BatchNorm being wrapped is 1D, so that SyncBatchNorm._check_input_dim(…) applies the same criteria as BatchNorm1d, as sketched in the following. There may be other exceptions, but the goal should be to wrap existing models transparently.

from torch import nn

class SyncBatchNorm(nn.SyncBatchNorm):
    def _check_input_dim(self, input):
        # _1d is set by convert_sync_batchnorm below; fall back to the stock check otherwise
        if getattr(self, '_1d', False):
            if input.dim() != 2 and input.dim() != 3:
                raise ValueError('expected 2D or 3D input (got {}D input)'
                                 .format(input.dim()))
        elif input.dim() <= 2:
            raise ValueError('expected at least 3D input (got {}D input)'
                             .format(input.dim()))

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        ...
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module_output = cls(module.num_features,
                                module.eps, module.momentum,
                                module.affine,
                                module.track_running_stats,
                                process_group)
            # remember whether the wrapped layer was a BatchNorm1d
            module_output._1d = isinstance(module, nn.modules.batchnorm.BatchNorm1d)
            ...
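The first workaround (unsqueeze(-1) before the converted layer) can also be automated instead of being edited into every forward(). This is a minimal sketch, assuming only standard nn.Module APIs. The wrapper Unsqueeze2dInput and the helper wrap_batchnorm1d are hypothetical names, not part of PyTorch. The wrapper does the unsqueeze(-1)/squeeze(-1) around the norm layer, and the helper applies it to every BatchNorm1d before convert_sync_batchnorm is called. Treat it as untested for multi-GPU training.

```python
import copy

import torch
from torch import nn


class Unsqueeze2dInput(nn.Module):
    """Hypothetical wrapper (not part of PyTorch): passes (N, C) inputs to the
    wrapped norm layer as (N, C, 1) and drops the extra axis afterwards."""

    def __init__(self, norm: nn.Module):
        super().__init__()
        self.norm = norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 2:
            return self.norm(x.unsqueeze(-1)).squeeze(-1)
        return self.norm(x)  # 3D inputs already have the (N, C, L) layout


def wrap_batchnorm1d(module: nn.Module) -> nn.Module:
    """Hypothetical helper: wraps every BatchNorm1d in `module` (in place)."""
    if isinstance(module, nn.BatchNorm1d):
        return Unsqueeze2dInput(module)
    for name, child in list(module.named_children()):
        setattr(module, name, wrap_batchnorm1d(child))
    return module


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(20, 100), nn.BatchNorm1d(100))
reference = copy.deepcopy(model)  # untouched copy for comparison
wrapped = wrap_batchnorm1d(model)

x = torch.randn(2, 20)
# Both modules are in training mode, so each uses the statistics of this batch.
out_ref = reference(x)
out_wrapped = wrapped(x)
print(torch.allclose(out_wrapped, out_ref), out_wrapped.shape)

# convert_sync_batchnorm recurses into the wrapper and swaps the inner layer.
synced = nn.SyncBatchNorm.convert_sync_batchnorm(wrapped)
print(type(synced[1].norm).__name__)
```

One design consequence: parameters move under an extra `.norm` level (for example `1.weight` becomes `1.norm.weight`), so checkpoints saved from the unwrapped model would need their keys remapped before load_state_dict.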

I faced the same problem. Any update?