How to implement conditional instance normalization

Hi

I’m going to implement conditional instance normalization. What I plan to do is pass the weight and bias for each sample as a second argument to the forward function, as follows:

```python
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Module


class ConditionalInstanceNorm(Module):
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_channels: int, eps: float = 1e-5, affine: bool = True) -> None:
        super().__init__()
        # instance norm == group norm with one group per channel
        self.num_groups = num_channels
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        self.register_parameter('weight', None)
        self.register_parameter('bias', None)

    def forward(self, input: Tensor, condition: Tensor) -> Tensor:
        # condition: [bs, num_channels, 2] -> per-sample weight and bias
        weight = condition[:, :, 0]
        bias = condition[:, :, 1]
        return F.group_norm(
            input, self.num_groups, weight, bias, self.eps)

    def extra_repr(self) -> str:
        return '{num_groups}, {num_channels}, eps={eps}, ' \
               'affine={affine}'.format(**self.__dict__)
```

Here is my question:
Are the weight and bias applied to each sample in the batch independently when I pass a tensor of shape `[bs, num_channels]` as the condition instead of `[num_channels]`? I tried to follow the source code until here but cannot completely understand it.

The parameters usually do not depend on the batch size, so you should make sure they have the shape `[num_channels]`:

```python
import torch.nn as nn

norm = nn.GroupNorm(num_groups=3, num_channels=9)
print(norm.weight.shape)
# > torch.Size([9])
print(norm.bias.shape)
# > torch.Size([9])
```

Also, note that you are using `self.weight` and `self.bias` in your code, which are both set to `None`, so you might want to change it.
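To illustrate the point about parameter shapes, here is a minimal NumPy sketch (not the PyTorch module itself, and assuming a `[bs, C, H, W]` input) of what instance norm with a standard affine transform computes: the `[C]`-shaped weight and bias broadcast over the batch dimension, so every sample is scaled by the same parameters.

```python
import numpy as np

bs, C, H, W = 2, 3, 4, 4
x = np.random.randn(bs, C, H, W)
weight = np.random.randn(C)  # one scale per channel, shared by all samples
bias = np.random.randn(C)    # one shift per channel, shared by all samples

# per-sample, per-channel normalization (instance norm; eps for stability)
mean = x.mean(axis=(2, 3), keepdims=True)
var = x.var(axis=(2, 3), keepdims=True)
normalized = (x - mean) / np.sqrt(var + 1e-5)

# the [C]-shaped parameters broadcast identically for every sample
out = normalized * weight[None, :, None, None] + bias[None, :, None, None]

# both samples were transformed with the same weight and bias
assert np.allclose(out[0], normalized[0] * weight[:, None, None] + bias[:, None, None])
assert np.allclose(out[1], normalized[1] * weight[:, None, None] + bias[:, None, None])
```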


Thanks for your reply. Yes, they should be independent, but since I’m implementing conditional instance norm, the `weight` and `bias` for each sample should be computed from another sample that serves as the condition. That is why I want to know whether `F.group_norm()` can handle a weight and bias of size `[bs, num_channels]`. I think I can instead apply the condition to the result of `F.group_norm(weight=None, bias=None)` rather than passing it as an argument to `F.group_norm()`:

```python
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Module


class ConditionalInstanceNorm(Module):
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']
    num_groups: int
    num_channels: int
    eps: float
    affine: bool

    def __init__(self, num_channels: int, eps: float = 1e-5, affine: bool = True) -> None:
        super().__init__()
        # instance norm == group norm with one group per channel
        self.num_groups = num_channels
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        self.register_parameter('weight', None)
        self.register_parameter('bias', None)

    def forward(self, input: Tensor, condition: Tensor) -> Tensor:
        # condition: [bs, num_channels, 2] -> per-sample weight and bias
        weight = condition[:, :, 0]
        bias = condition[:, :, 1]
        # normalize without an affine transform (self.weight/self.bias are None)
        unnormalized = F.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps)
        # broadcast the per-sample parameters to [bs, num_channels, 1, 1]
        weight = weight.unsqueeze(-1).unsqueeze(-1)
        bias = bias.unsqueeze(-1).unsqueeze(-1)
        normalized = (unnormalized * weight) + bias
        return normalized

    def extra_repr(self) -> str:
        return '{num_groups}, {num_channels}, eps={eps}, ' \
               'affine={affine}'.format(**self.__dict__)
```
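A minimal NumPy sketch (not the module above, and assuming a `[bs, C, H, W]` input) of why this unsqueeze-and-broadcast step applies a distinct affine transform per sample:

```python
import numpy as np

bs, C, H, W = 2, 3, 4, 4
x = np.random.randn(bs, C, H, W)
condition = np.random.randn(bs, C, 2)  # per-sample weight and bias

# affine-free instance norm, as with F.group_norm(weight=None, bias=None)
mean = x.mean(axis=(2, 3), keepdims=True)
var = x.var(axis=(2, 3), keepdims=True)
unnormalized = (x - mean) / np.sqrt(var + 1e-5)

# reshape the per-sample parameters to [bs, C, 1, 1] for broadcasting,
# mirroring the two unsqueeze(-1) calls in the module
weight = condition[:, :, 0][:, :, None, None]
bias = condition[:, :, 1][:, :, None, None]
out = unnormalized * weight + bias

# each sample in the batch got its own affine parameters
for i in range(bs):
    assert np.allclose(
        out[i],
        unnormalized[i] * condition[i, :, 0][:, None, None]
        + condition[i, :, 1][:, None, None])
```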