Batch norm training mode error despite model.eval()

Hello. I am working on a TabNet model and ran into an issue when performing real-time prediction on a single input data point (batch_size = 1). Despite calling model.eval(), it still throws the following error:
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 128])

This is the Ghost Batch Normalization method that I am using:

import torch


class GhostBatchNorm(torch.nn.Module):
    """Ghost Batch Normalization: splits the incoming batch into chunks and
    applies a shared BatchNorm1d to each chunk separately."""

    def __init__(self, num_features, n_chunks=128, chunk_size=0,
                 track_running_stats=False, momentum=0.02):
        super(GhostBatchNorm, self).__init__()

        self.num_features = num_features
        self.n_chunks = n_chunks
        self.chunk_size = chunk_size
        self.trs = track_running_stats
        self.momentum = momentum

        self.batch_norm = torch.nn.BatchNorm1d(self.num_features,
                                               track_running_stats=self.trs,
                                               momentum=self.momentum)

        print('batch training: ', self.batch_norm.training)

    def forward(self, batch):
        # If a chunk size is given, derive the number of chunks from the batch size
        if self.chunk_size:
            self.n_chunks = batch.size(0) // self.chunk_size

        batch_chunks = torch.chunk(batch, self.n_chunks, dim=0)

        print('batch training: ', self.batch_norm.training)

        # Normalize each chunk ("ghost batch") independently
        norm_batch_chunks = [self.batch_norm(chunk) for chunk in batch_chunks]

        return torch.cat(norm_batch_chunks, dim=0)
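
For context, a minimal standalone snippet (using num_features=128 to match the error message) that should reproduce the problem:

import torch

gbn = GhostBatchNorm(num_features=128, track_running_stats=False)
gbn.eval()                # sets batch_norm.training = False

x = torch.randn(1, 128)   # a single data point, as in real-time prediction
out = gbn(x)
# ValueError: Expected more than 1 value per channel when training,
# got input size torch.Size([1, 128])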

I also noticed some strange behaviour: when __init__ is called, self.batch_norm.training is True (regardless of model.eval()), but when the forward method is called the same attribute is False. However, when the next line is executed:
norm_batch_chunks = [self.batch_norm(chunk) for chunk in batch_chunks]
it raises the aforementioned error (despite self.batch_norm.training being False).

After calling model.eval() I also printed whether the batch_norm is in training mode:
print('Model training: ', model.attentive_transformer[0].gbn.batch_norm.training)
and it printed False.

I am really perplexed. Why does the initialization of batch_norm inside the GhostBatchNorm class report training = True, and why is the error raised despite model.eval()?
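
As far as I can tell (this is about nn.Module defaults in general, not anything specific to TabNet), a freshly constructed module defaults to training mode, and model.eval(), called after construction, recursively switches every registered submodule to eval mode. That would explain why __init__ prints True while forward prints False:

import torch

gbn = GhostBatchNorm(num_features=128, track_running_stats=False)
print(gbn.batch_norm.training)      # True: nn.Module defaults to training mode

model = torch.nn.Sequential(gbn)    # stand-in for the full TabNet model
model.eval()                        # recursively sets training = False on all children
print(gbn.batch_norm.training)      # False: what forward() sees at inference time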

I investigated the source of the error a little further.
First file: batchnorm.py

class _BatchNorm(_NormBase):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        """ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        print('BatchNorm training: ', self.training)
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        """Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """

        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight, self.bias, bn_training, exponential_average_factor, self.eps)

I added this line here: print('BatchNorm training: ', self.training), and it always prints False.

Second file: functional.py

def batch_norm(input, running_mean, running_var, weight=None, bias=None,
               training=False, momentum=0.1, eps=1e-5):
    # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, float, float) -> Tensor  # noqa
    r"""Applies Batch Normalization for each channel across a batch of data.

    See :class:`~torch.nn.BatchNorm1d`, :class:`~torch.nn.BatchNorm2d`,
    :class:`~torch.nn.BatchNorm3d` for details.
    """
    if not torch.jit.is_scripting():
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(
                batch_norm, (input,), input, running_mean, running_var, weight=weight,
                bias=bias, training=training, momentum=momentum, eps=eps)
    print('functional BatchNorm training: ', training)
    if training:
        _verify_batch_size(input.size())

    return torch.batch_norm(
        input, weight, bias, running_mean, running_var,
        training, momentum, eps, torch.backends.cudnn.enabled
    )

In this file, print('functional BatchNorm training: ', training) prints True.

So _BatchNorm.forward (with self.training = False) calls F.batch_norm, where suddenly training = True. I did not understand this behaviour at first.

Culprit found: these lines from class _BatchNorm in batchnorm.py cause the error:

        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

  • Because self.training = False, the else branch is taken.
  • With track_running_stats = False, self.running_mean and self.running_var are None.
  • (self.running_mean is None) and (self.running_var is None) evaluates to True, so bn_training = True.
  • bn_training is passed to F.batch_norm as training, so training = True.
  • With training = True, F.batch_norm calls _verify_batch_size(input.size()), which fails for a single-sample batch.
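
This can be checked with a bare BatchNorm1d, independently of GhostBatchNorm (again using 128 features to match the error message):

import torch

bn_no_stats = torch.nn.BatchNorm1d(128, track_running_stats=False).eval()
print(bn_no_stats.running_mean, bn_no_stats.running_var)   # None None

x = torch.randn(1, 128)
try:
    bn_no_stats(x)
except ValueError as e:
    print(e)   # Expected more than 1 value per channel when training, ...

bn_with_stats = torch.nn.BatchNorm1d(128, track_running_stats=True).eval()
print(bn_with_stats(x).shape)   # torch.Size([1, 128]) -- normalized with the running buffers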

Is it correct behaviour that with track_running_stats = False the batch-norm layer effectively stays in training mode? What can I change to make it work properly? I had to use track_running_stats = False because otherwise the evaluation loss is huge!

Any help would be much appreciated.

Radek


Yes: since running statistics are disabled, the batchnorm layer cannot use them to normalize the input and has to compute the statistics from the incoming batch itself, which is why it needs more than one value per channel even in eval mode.
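
For completeness, if track_running_stats=False has to stay, one possible workaround (only a sketch, under the assumption that skipping normalization for a lone inference sample is acceptable in your setting) is to special-case single-sample batches in GhostBatchNorm.forward and apply only the learned affine transform, since no batch statistics can be computed from one row:

# Inside GhostBatchNorm (replacing the forward method shown above):
def forward(self, batch):
    # Hypothetical guard (not part of the original module): at eval time,
    # with running stats disabled and a single sample, batch statistics
    # cannot be computed, so keep only the learned affine transform.
    if (not self.training) and (not self.trs) and batch.size(0) == 1:
        if self.batch_norm.affine:
            return batch * self.batch_norm.weight + self.batch_norm.bias
        return batch

    if self.chunk_size:
        self.n_chunks = batch.size(0) // self.chunk_size

    batch_chunks = torch.chunk(batch, self.n_chunks, dim=0)
    norm_batch_chunks = [self.batch_norm(chunk) for chunk in batch_chunks]
    return torch.cat(norm_batch_chunks, dim=0)

Alternatively, you could re-enable track_running_stats=True and tune momentum so the running buffers follow the ghost-batch statistics more closely, so that eval() can normalize a single sample with the buffers, at the risk of the train/eval mismatch you observed.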