How to implement conditional instance normalization

Hi

I’m trying to implement conditional instance normalization. My plan is to pass the per-sample weight and bias as a second argument to the forward function, as follows:

import torch
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F


class ConditionalInstanceNorm(Module):

    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_channels: int, eps: float = 1e-5, affine: bool = True) -> None:
        super(ConditionalInstanceNorm, self).__init__()
        # One group per channel, so group norm behaves like instance norm.
        self.num_groups = num_channels
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        # No learnable affine parameters; they are supposed to come from the condition.
        self.register_parameter('weight', None)
        self.register_parameter('bias', None)

    def forward(self, input: Tensor, condition: Tensor) -> Tensor:
        # condition is expected to have shape [bs, num_channels, 2]
        weight = condition[:, :, 0]
        bias = condition[:, :, 1]
        return F.group_norm(
            input, self.num_groups, weight, bias, self.eps)

    def extra_repr(self) -> str:
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)

Here is my question:
Are the weight and bias applied to each sample in the batch independently when I pass a condition tensor of shape [bs, num_channels] instead of [num_channels]? I tried to follow the source code up to this point, but I cannot understand it completely.

The parameters usually do not depend on the batch size, so you should make sure they have the shape [num_channels]:

import torch.nn as nn

norm = nn.GroupNorm(num_groups=3, num_channels=9)
print(norm.weight.shape)
> torch.Size([9])
print(norm.bias.shape)
> torch.Size([9])

Also, note that you are using self.weight and self.bias in your code, which are both set to None, so you might want to change it.
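As a quick check (a minimal sketch, assuming a recent PyTorch version), the functional F.group_norm also expects 1-D affine parameters of size num_channels and raises a shape error for a batched [bs, num_channels] weight:

import torch
import torch.nn.functional as F

x = torch.randn(4, 9, 8, 8)  # [bs, num_channels, H, W]

# a 1-D weight/bias of shape [num_channels] is accepted
out = F.group_norm(x, num_groups=9, weight=torch.ones(9), bias=torch.zeros(9))
print(out.shape)
> torch.Size([4, 9, 8, 8])

# a batched [bs, num_channels] weight is rejected
try:
    F.group_norm(x, num_groups=9, weight=torch.ones(4, 9))
except RuntimeError as e:
    print('per-sample weight is not supported:', e)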


Thanks for your reply. Yes, they should be independent, but since I’m implementing conditional instance norm, the weight and bias for each sample should be computed from another sample that serves as the condition. That is why I want to know whether F.group_norm() can handle weight and bias of size [bs, num_channels]. I think I can instead call F.group_norm() with weight=None and bias=None and apply the condition to its result, rather than passing the condition as an argument to F.group_norm():

import torch
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F


class ConditionalInstanceNorm(Module):

    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_channels: int, eps: float = 1e-5, affine: bool = True) -> None:
        super(ConditionalInstanceNorm, self).__init__()
        # One group per channel, so group norm behaves like instance norm.
        self.num_groups = num_channels
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        # No learnable affine parameters; scale and shift come from the condition.
        self.register_parameter('weight', None)
        self.register_parameter('bias', None)

    def forward(self, input: Tensor, condition: Tensor) -> Tensor:
        # condition: [bs, num_channels, 2] -> per-sample weight and bias
        weight = condition[:, :, 0]
        bias = condition[:, :, 1]
        # Normalize without an affine transform (self.weight and self.bias are None).
        normalized = F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)
        # Broadcast the per-sample parameters over the spatial dimensions.
        weight = weight.unsqueeze(-1).unsqueeze(-1)
        bias = bias.unsqueeze(-1).unsqueeze(-1)
        return normalized * weight + bias

    def extra_repr(self) -> str:
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)
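
For completeness, a minimal usage sketch: the condition tensor here is random just for illustration (in a real model it would be produced from the conditioning sample, e.g. by a small embedding network), and its [bs, num_channels, 2] shape is an assumption taken from the indexing above.

bs, num_channels = 4, 9
norm = ConditionalInstanceNorm(num_channels)

x = torch.randn(bs, num_channels, 16, 16)
condition = torch.randn(bs, num_channels, 2)  # stand-in for a learned conditioning embedding

out = norm(x, condition)
print(out.shape)
> torch.Size([4, 9, 16, 16])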