How to implement conditional instance normalization

Hi

I’m going to implement conditional instance normalization. My plan is to pass the weight and bias for each sample as a second argument to the forward function, as follows:

```python
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module


class ConditionalInstanceNorm(Module):

    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_channels: int, eps: float = 1e-5, affine: bool = True) -> None:
        super(ConditionalInstanceNorm, self).__init__()
        # One group per channel, so group norm behaves like instance norm
        self.num_groups = num_channels
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        # No learnable affine parameters: scale and shift come from `condition`
        self.register_parameter('weight', None)
        self.register_parameter('bias', None)

    def forward(self, input: Tensor, condition: Tensor) -> Tensor:
        # condition: [bs, num_channels, 2], so weight and bias are [bs, num_channels]
        weight = condition[:, :, 0]
        bias = condition[:, :, 1]
        return F.group_norm(
            input, self.num_groups, weight, bias, self.eps)

    def extra_repr(self) -> str:
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)
```

Here is my question:
when I pass a condition of shape `[bs, num_channels]` instead of `[num_channels]`, are the weight and bias applied to each sample in the batch independently? I tried to follow the source code up to here but cannot understand it completely.

The affine parameters usually do not depend on the batch size, so you should make sure they have the shape `[num_channels]`:

```python
import torch.nn as nn

norm = nn.GroupNorm(num_groups=3, num_channels=9)
print(norm.weight.shape)
# torch.Size([9])
print(norm.bias.shape)
# torch.Size([9])
```
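A quick way to see this shape requirement in action: the sketch below (tensor sizes chosen arbitrarily for illustration) calls `F.group_norm` once with a shared `[num_channels]` weight and bias, and once with per-sample `[bs, num_channels]` tensors. The functional form expects `weight` and `bias` to be vectors with `num_channels` elements, so the per-sample tensors are rejected with an error rather than broadcast over the batch.

```python
import torch
import torch.nn.functional as F

bs, num_channels, h, w = 2, 4, 8, 8
x = torch.randn(bs, num_channels, h, w)

# One weight/bias vector shared by every sample: accepted
shared_weight = torch.ones(num_channels)
shared_bias = torch.zeros(num_channels)
out = F.group_norm(x, num_channels, shared_weight, shared_bias, 1e-5)
print(out.shape)  # torch.Size([2, 4, 8, 8])

# One weight/bias vector per sample: rejected, not applied per sample
per_sample_weight = torch.ones(bs, num_channels)
per_sample_bias = torch.zeros(bs, num_channels)
try:
    F.group_norm(x, num_channels, per_sample_weight, per_sample_bias, 1e-5)
    rejected = False
except RuntimeError as err:
    rejected = True
    print("RuntimeError:", err)
```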

Also, note that you are using `self.weight` and `self.bias` in your code, which are both set to `None`, so you might want to change it.


Thanks for your reply. Yes, they should be independent of the batch, but since I’m implementing conditional instance norm, the `weight` and `bias` for each sample should be computed from another sample that acts as the condition. That is why I want to know whether `F.group_norm()` can handle a weight and bias of size `[bs, num_channels]`. I think I can instead call `F.group_norm(bias=None, weight=None)` and then scale and shift its result with the condition, rather than passing the condition as an argument to `F.group_norm()`:
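Written out, this approach normalizes each channel of each sample with its own statistics and then applies a per-sample, per-channel scale and shift. Here $\gamma_{n,c}$ and $\beta_{n,c}$ are notation introduced for illustration: the scale and shift taken from the condition for sample $n$ and channel $c$.

```latex
\mu_{n,c} = \frac{1}{HW}\sum_{h,w} x_{n,c,h,w}, \qquad
\sigma^2_{n,c} = \frac{1}{HW}\sum_{h,w} \left(x_{n,c,h,w} - \mu_{n,c}\right)^2

y_{n,c,h,w} = \gamma_{n,c}\,\frac{x_{n,c,h,w} - \mu_{n,c}}{\sqrt{\sigma^2_{n,c} + \epsilon}} + \beta_{n,c}
```

With `num_groups = num_channels`, `F.group_norm` without weight and bias computes the normalized fraction. Because $\gamma$ and $\beta$ carry the batch index $n$, they have to be applied after that call, for example by broadcasting tensors of shape `[bs, num_channels, 1, 1]`.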

```python
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module


class ConditionalInstanceNorm(Module):

    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_channels: int, eps: float = 1e-5, affine: bool = True) -> None:
        super(ConditionalInstanceNorm, self).__init__()
        # One group per channel, so group norm behaves like instance norm
        self.num_groups = num_channels
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        self.register_parameter('weight', None)
        self.register_parameter('bias', None)

    def forward(self, input: Tensor, condition: Tensor) -> Tensor:
        # condition: [bs, num_channels, 2] holding per-sample scale and shift
        weight = condition[:, :, 0]
        bias = condition[:, :, 1]
        # self.weight and self.bias are None, so group_norm only normalizes
        normalized = F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)
        # [bs, num_channels] -> [bs, num_channels, 1, 1] to broadcast over H and W
        weight = weight.unsqueeze(-1).unsqueeze(-1)
        bias = bias.unsqueeze(-1).unsqueeze(-1)
        return normalized * weight + bias

    def extra_repr(self) -> str:
        return '{num_groups}, {num_channels}, eps={eps}, ' \
            'affine={affine}'.format(**self.__dict__)
```
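To check that the manual scale-and-shift matches applying `F.group_norm` sample by sample, the sketch below compares both on random data. It assumes a 4-D input `[bs, num_channels, H, W]` and a condition of shape `[bs, num_channels, 2]`, as sliced in the code above; the helper name `conditional_instance_norm` is hypothetical. It uses `F.instance_norm`, which normalizes each channel of each sample over H and W; without running statistics this should agree with `F.group_norm` when `num_groups = num_channels`.

```python
import torch
import torch.nn.functional as F


def conditional_instance_norm(x: torch.Tensor, condition: torch.Tensor,
                              eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: instance norm followed by a per-sample affine.

    x:         [bs, C, H, W]
    condition: [bs, C, 2], where [..., 0] is the scale and [..., 1] the shift
    """
    gamma = condition[:, :, 0]  # [bs, C]
    beta = condition[:, :, 1]   # [bs, C]
    # Normalize each (sample, channel) slice over H and W, with no affine
    x_hat = F.instance_norm(x, eps=eps)
    # Reshape to [bs, C, 1, 1] so the parameters broadcast over H and W
    return gamma[:, :, None, None] * x_hat + beta[:, :, None, None]


torch.manual_seed(0)
bs, C, H, W = 3, 4, 5, 5
x = torch.randn(bs, C, H, W)
condition = torch.randn(bs, C, 2)

out = conditional_instance_norm(x, condition)

# Reference: call F.group_norm once per sample, passing that sample's
# [C]-shaped weight and bias, which is the shape it accepts
reference = torch.cat([
    F.group_norm(x[n:n + 1], C, condition[n, :, 0], condition[n, :, 1], 1e-5)
    for n in range(bs)
])

print(torch.allclose(out, reference, atol=1e-5))  # True
```

The broadcasting step is the same one the second `ConditionalInstanceNorm` performs with `unsqueeze(-1).unsqueeze(-1)`; vectorizing it this way avoids the per-sample loop, which is only used here as a reference.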