Trying to understand how to use instance normalization, weight normalization, layer normalization and group normalization instead of batch normalization

Hi all,

I have a question concerning how to use instance normalization, weight norm, layer norm and group norm instead of batch normalization.
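For reference, here is a minimal sketch (assuming a recent PyTorch with `torch.nn.functional`) of which axes each activation normalization averages over for a 4-D tensor of shape (N, C, H, W). Each one is checked against the functional API. Batch norm shares statistics per channel across the whole batch. Instance norm uses statistics per sample and per channel. Layer norm uses statistics per sample over (C, H, W). Group norm uses statistics per sample over groups of C/G channels. Weight norm is different: it reparameterizes a layer's weights instead of normalizing activations, so it is not a drop-in layer. The helper `manual_norm`, the tensor sizes and `G = 4` are illustrative choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 8, 16, 5, 5
G = 4        # number of groups for group norm (illustrative choice)
eps = 1e-5
x = torch.randn(N, C, H, W)


def manual_norm(t, dims):
    # Normalize to zero mean / unit (biased) variance over the given dims.
    mean = t.mean(dim=dims, keepdim=True)
    var = t.var(dim=dims, unbiased=False, keepdim=True)
    return (t - mean) / torch.sqrt(var + eps)


# Batch norm (training mode): one mean/var per channel, over (N, H, W).
bn = manual_norm(x, (0, 2, 3))
ref_bn = F.batch_norm(x, None, None, training=True, eps=eps)

# Instance norm: one mean/var per sample and per channel, over (H, W).
inorm = manual_norm(x, (2, 3))
ref_in = F.instance_norm(x, eps=eps)

# Layer norm: one mean/var per sample, over (C, H, W).
ln = manual_norm(x, (1, 2, 3))
ref_ln = F.layer_norm(x, (C, H, W), eps=eps)

# Group norm: split C into G consecutive groups, one mean/var per sample
# and per group, over (C/G, H, W).
gn = manual_norm(x.view(N, G, C // G, H, W), (2, 3, 4)).view(N, C, H, W)
ref_gn = F.group_norm(x, G, eps=eps)

for name, mine, ref in [("batch", bn, ref_bn), ("instance", inorm, ref_in),
                        ("layer", ln, ref_ln), ("group", gn, ref_gn)]:
    print(name, torch.allclose(mine, ref, atol=1e-4))
```

The only thing that changes between the four is the set of axes the statistics are computed over; the learnable scale/shift (omitted here) is applied afterwards.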

Could someone please explain how to replace batch norm with each of the other normalizations in the following example, just so I can better understand how they work?

Thanks a lot for your help.

Here is the example:

    import torch
    import torch.nn as nn


    class ModelExample(torch.nn.Module):

        def __init__(self):
            super(ModelExample, self).__init__()

            self.linear = torch.nn.Linear(100, 1024*4*4)

            # bias=False is typical when a normalization layer follows the convolution
            self.conv1 = nn.Sequential(
                nn.ConvTranspose2d(in_channels=1024, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False))
            self.conv2 = nn.Sequential(
                nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False))
            self.out = torch.nn.Tanh()

        def forward(self, x):

            x = self.linear(x)                   # (N, 16384)
            x = x.view(x.shape[0], 1024, 4, 4)   # (N, 1024, 4, 4)

            x = self.conv1(x)                    # (N, 128, 8, 8)
            x = self.conv2(x)                    # (N, 3, 16, 16)

            x = self.out(x)

            return x
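One way to see the replacements is to give the example a hypothetical `norm` argument and build the layer that would sit after `conv1`. This sketch assumes the usual DCGAN-style placement (normalization followed by ReLU after the first transposed convolution, nothing before `Tanh`). The `make_norm` helper, the `norm` names, the ReLU and `num_groups=32` are illustrative assumptions, not part of the original post. Only the constructor changes: `nn.BatchNorm2d(128)` becomes `nn.InstanceNorm2d(128, affine=True)` (affine is off by default), `nn.GroupNorm(32, 128)`, or `nn.LayerNorm([128, 8, 8])`. LayerNorm needs the full (C, H, W) shape, here 8x8 because `conv1` upsamples 4x4 to 8x8. Weight norm is not a layer at all: it wraps the convolution via `torch.nn.utils.weight_norm`. For `ConvTranspose2d` the weight has shape (in_channels, out_channels, kH, kW), so `dim=1` gives one gain per output channel, while the default `dim=0` would give one per input channel.

```python
import torch
import torch.nn as nn


def make_norm(kind, channels, spatial):
    # Hypothetical helper: returns the layer that replaces nn.BatchNorm2d(channels).
    if kind == "batch":
        return nn.BatchNorm2d(channels)
    if kind == "instance":
        return nn.InstanceNorm2d(channels, affine=True)
    if kind == "layer":
        # LayerNorm normalizes over the trailing (C, H, W) dims,
        # so it needs the spatial size of the feature map.
        return nn.LayerNorm([channels, spatial, spatial])
    if kind == "group":
        return nn.GroupNorm(num_groups=32, num_channels=channels)
    if kind == "weight":
        # Weight norm is applied to the conv itself; no activation-norm layer.
        return nn.Identity()
    raise ValueError(f"unknown norm: {kind}")


class ModelExample(nn.Module):

    def __init__(self, norm="batch"):
        super().__init__()
        self.linear = nn.Linear(100, 1024 * 4 * 4)

        # A following norm layer's shift makes the conv bias redundant;
        # weight norm has no shift, so keep the bias in that case.
        deconv1 = nn.ConvTranspose2d(1024, 128, kernel_size=4, stride=2,
                                     padding=1, bias=(norm == "weight"))
        if norm == "weight":
            # ConvTranspose2d weight is (in, out, kH, kW): dim=1 -> one gain per output channel.
            deconv1 = nn.utils.weight_norm(deconv1, dim=1)

        # conv1 maps 4x4 -> 8x8, so the norm layer sees (N, 128, 8, 8).
        self.conv1 = nn.Sequential(deconv1, make_norm(norm, 128, 8), nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1, bias=False))
        self.out = nn.Tanh()

    def forward(self, x):
        x = self.linear(x)
        x = x.view(x.shape[0], 1024, 4, 4)
        x = self.conv1(x)   # (N, 128, 8, 8)
        x = self.conv2(x)   # (N, 3, 16, 16)
        return self.out(x)


shapes = {}
for kind in ["batch", "instance", "layer", "group", "weight"]:
    model = ModelExample(norm=kind)
    shapes[kind] = tuple(model(torch.randn(2, 100)).shape)
print(shapes)
```

Newer PyTorch releases mark `torch.nn.utils.weight_norm` as deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`, which stores the gain and direction under `module.parametrizations.weight` instead of `weight_g` / `weight_v`.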