Conditional Batch Normalization?

Conditional Batch Normalization was proposed recently, and a few recent works suggest it has some interesting properties and gives good performance on certain tasks. In this work, the authors implemented a variant of conditional BN in TensorFlow that learns a different scale and shift for each class.
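For reference, the per-class variant can be written as below. This follows the usual batch norm formula; the notation is my own. Here μ_k and σ_k² are the mean and variance of channel k over the batch and spatial dimensions, c_n is the class label of sample n, and γ, β are (n_labels × num_features) matrices:

$$\hat{x}_{n,k,h,w} = \frac{x_{n,k,h,w} - \mu_k}{\sqrt{\sigma_k^2 + \epsilon}}, \qquad y_{n,k,h,w} = \gamma_{c_n,k}\,\hat{x}_{n,k,h,w} + \beta_{c_n,k}$$

Only the scale and shift depend on the label. The normalization statistics are shared by the whole batch, exactly as in ordinary BN.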

I tried to implement it in PyTorch with the following code:

import torch
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
from torch.nn import functional as F


# TODO: check contiguous in THNN
# TODO: use separate backend functions?
class _CondBatchNorm(Module):

    def __init__(self, num_features, n_labels, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_CondBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(n_labels, num_features))
            self.bias = Parameter(torch.Tensor(n_labels, num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input):
        return NotImplemented

    def forward(self, input, labels):
        self._check_input_dim(input)
        # Look up one row of scale/shift per sample: shape (N, num_features)
        weight_per_sample = F.embedding(labels.long(), self.weight)
        bias_per_sample = F.embedding(labels.long(), self.bias)
        return F.batch_norm(
            input, self.running_mean, self.running_var, weight_per_sample, bias_per_sample,
            self.training or not self.track_running_stats, self.momentum, self.eps)

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' affine={affine}, track_running_stats={track_running_stats})'
                .format(name=self.__class__.__name__, **self.__dict__))

However, it seems that the low-level torch._C.batch_norm can't handle a weight/bias matrix of shape (batch_size, num_features), with one row per sample, instead of a vector of shape (num_features,). In comparison, TensorFlow seems to support it well (here). I wonder whether there is something wrong with the way I implemented conditional BN, or whether Torch/PyTorch currently can't support it.
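Here is a small shape check that illustrates the mismatch (toy sizes chosen arbitrarily). F.embedding turns the (n_labels, C) parameter into a per-sample (N, C) matrix. F.batch_norm's weight argument, however, is a single per-channel vector of shape (C,) shared by every sample. If you normalize without weights and reshape the per-sample rows to (N, C, 1, 1), they broadcast against an (N, C, H, W) activation instead:

```python
import torch
import torch.nn.functional as F

N, C, H, W, n_labels = 4, 8, 5, 5, 3
weight = torch.randn(n_labels, C)          # one scale row per label, as in _CondBatchNorm
labels = torch.tensor([0, 1, 1, 2])

per_sample = F.embedding(labels, weight)   # row lookup -> shape (N, C)
print(per_sample.shape)                    # torch.Size([4, 8])

# Normalize only (weight=None, bias=None), using batch statistics.
x_hat = F.batch_norm(torch.randn(N, C, H, W), None, None, training=True)

# Broadcast (N, C, 1, 1) over (N, C, H, W): sample n is scaled by row labels[n].
y = per_sample.view(N, C, 1, 1) * x_hat
print(y.shape)                             # torch.Size([4, 8, 5, 5])
```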


Hi @haoyangz, by any chance have you found a solution to this issue? I was also working in a similar area and got stuck.

The most common and code-efficient way to implement conditional batch norm in PyTorch is to use a batch norm without learnable weights (affine=False) and then apply the per-class scale and shift yourself.
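A minimal sketch of that approach for 4D inputs follows. It is not the SNGAN code mentioned in this thread, and the class name ConditionalBatchNorm2d and its arguments are hypothetical. nn.BatchNorm2d(affine=False) does the normalization and running-statistics bookkeeping. Two nn.Embedding tables hold the per-class scale and shift, each of shape (num_classes, num_features), like the weight/bias matrices in _CondBatchNorm above. Initializing the scale to 1 and the shift to 0 is my assumption, chosen so the layer starts out as plain BN:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionalBatchNorm2d(nn.Module):
    """Hypothetical sketch: batch norm with a separate scale/shift per class."""

    def __init__(self, num_features, num_classes, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        # affine=False: this BN only normalizes and tracks running statistics.
        self.bn = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=False)
        # One row of gamma and one row of beta per class.
        self.gamma = nn.Embedding(num_classes, num_features)
        self.beta = nn.Embedding(num_classes, num_features)
        nn.init.ones_(self.gamma.weight)   # assumed init: identity scale
        nn.init.zeros_(self.beta.weight)   # assumed init: zero shift

    def forward(self, x, labels):
        out = self.bn(x)                                                 # (N, C, H, W)
        gamma = self.gamma(labels).view(-1, self.num_features, 1, 1)     # (N, C, 1, 1)
        beta = self.beta(labels).view(-1, self.num_features, 1, 1)       # (N, C, 1, 1)
        return gamma * out + beta                                        # per-sample affine


torch.manual_seed(0)
cbn = ConditionalBatchNorm2d(num_features=8, num_classes=10)
x = torch.randn(4, 8, 16, 16)
labels = torch.tensor([0, 3, 3, 9])   # LongTensor of class indices, one per sample
y = cbn(x, labels)
print(y.shape)                        # torch.Size([4, 8, 16, 16])
```

Because the affine step happens outside F.batch_norm, the (N, C) per-sample parameters never need to fit the (C,) weight slot. Running mean/variance are still handled by the standard module.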

I once implemented it this way when re-implementing SNGAN.

Best regards

Thomas


Thanks Thomas, very helpful