How to set different weight initialization parameters for each layer?

Hello,

I defined the following function to initialize the weights of the different layers of my network.

I have 5 convolutional layers with the same dimensions,
and 3 linear layers with the same dimensions.

Does this way of initializing the network ensure that all the layers start with different weights?

Weight(conv1) not equal Weight(conv2) not equal Weight(conv3) not equal Weight(conv4) not equal Weight(conv5)

and

Weight(linear1) not equal Weight(linear2) not equal Weight(linear3)

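import numpy as np
import torch
import torch.nn as nn

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)

def weights_init(m):
    # Check module types with isinstance instead of matching class names:
    # model.apply() also calls this on the top-level model, and the name
    # 'ConvNet' contains 'Conv' even though that module has no .weight.
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        m.weight.data.normal_(0.0, 0.02)
    elif isinstance(m, nn.Linear):
        # get the number of inputs
        n = m.in_features
        y = 1.0 / np.sqrt(n)
        m.weight.data.uniform_(-y, y)
        m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # modules have no normal_()/constant_() methods; act on the tensors
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

model = ConvNet()  # ConvNet is defined elsewhere (not shown)
model.apply(weights_init)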
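Since ConvNet isn't shown, here is a minimal sketch that checks the question directly on a hypothetical stand-in model (five identical Conv2d layers and three identical Linear layers; the sizes are arbitrary). Setting the seed once only fixes the generator's starting state; every in-place random initializer consumes values from it, so layers initialized one after another receive different weights. The helper all_different is made up for this check.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)


class ConvNet(nn.Module):
    """Hypothetical stand-in: 5 identical conv layers, 3 identical linear layers."""

    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([nn.Conv2d(4, 4, 3) for _ in range(5)])
        self.linears = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])


@torch.no_grad()
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        m.weight.normal_(0.0, 0.02)
    elif isinstance(m, nn.Linear):
        y = 1.0 / m.in_features ** 0.5
        m.weight.uniform_(-y, y)
        m.bias.fill_(0)


def all_different(layers):
    """True if no two layers in the list have identical weights."""
    ws = [layer.weight for layer in layers]
    return all(
        not torch.equal(ws[i], ws[j])
        for i in range(len(ws))
        for j in range(i + 1, len(ws))
    )


model = ConvNet()
model.apply(weights_init)
print(all_different(model.convs), all_different(model.linears))  # True True
```

Re-running the script with the same seed reproduces the same weights, while the layers still differ from each other.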

You can drop all the .data accesses and decorate the function instead:

@torch.no_grad()
def weights_init(m):
  # Your code

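And yes, this will reinitialize all the weights with random values. Each in-place random initializer advances the random generator, so every layer gets different values even though the seed is set only once.
You might be interested in the torch.nn.init package, which provides many common initialization methods.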
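As a sketch of what that could look like, here is the same scheme rewritten with torch.nn.init functions (normal_, uniform_, zeros_). The toy Sequential model is hypothetical and only there to show the call; the nn.init functions already disable gradient tracking internally, so neither .data nor a decorator is needed.

```python
import math

import torch
import torch.nn as nn


def weights_init(m):
    # Same scheme as the question, written with torch.nn.init.
    # The nn.init functions already run under torch.no_grad() internally.
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.Linear):
        bound = 1.0 / math.sqrt(m.in_features)
        nn.init.uniform_(m.weight, -bound, bound)
        nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)


# Hypothetical toy model, only to show the call.
torch.manual_seed(42)
model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.BatchNorm2d(8),
    nn.Flatten(),
    nn.Linear(8 * 6 * 6, 10),
)
model.apply(weights_init)
```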


Thank you for your answer @albanD,

Is this right?

import numpy as np
import torch
import torch.nn as nn

@torch.no_grad()
def weights_init(m):
    # isinstance avoids matching the model class itself (e.g. 'ConvNet')
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        m.weight.normal_(0.0, 0.02)
    elif isinstance(m, nn.Linear):
        # get the number of inputs
        n = m.in_features
        y = 1.0 / np.sqrt(n)
        m.weight.uniform_(-y, y)
        m.bias.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # call the in-place initializers on the tensors, not the module
        m.weight.normal_(1.0, 0.02)
        m.bias.fill_(0)

Is @torch.no_grad() different from torch.no_grad()?

What is wrong with .data?

They are the same thing used in two ways: torch.no_grad() works as a context manager, and @torch.no_grad() applies it as a function decorator. Decorating a function is equivalent to running its whole body inside a with torch.no_grad(): block, which is quite convenient.
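A minimal sketch of why one of the two forms is needed at all, using a single tensor w as a hypothetical stand-in for a layer's weight: both forms allow the in-place initialization and leave w a trainable leaf, while the same call outside no_grad is rejected by autograd.

```python
import torch

w = torch.zeros(3, requires_grad=True)  # a leaf tensor, like a layer's weight


@torch.no_grad()
def init_decorated(t):
    t.normal_(0.0, 0.02)


def init_with_block(t):
    with torch.no_grad():
        t.normal_(0.0, 0.02)


# Both forms allow the in-place update and keep w a trainable leaf.
init_decorated(w)
init_with_block(w)
print(w.requires_grad, w.is_leaf)  # True True

# Without no_grad, autograd rejects in-place ops on a leaf that requires grad.
raised = False
try:
    w.normal_(0.0, 0.02)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```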

The problem is that .data gives you a tensor whose in-place changes autograd does not track, so it can hide errors and silently give you wrong gradients.
For example, this issue popped up today: https://github.com/pytorch/pytorch/issues/30073, which is caused by the use of .data in the old codebase.
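A minimal sketch of that kind of silent error (the tensors are made up for illustration): zeroing a result in place through .data goes unnoticed, so backward() runs and returns wrong gradients, while doing the same through .detach() makes backward() raise an error instead, because detach() shares autograd's version counter with the original tensor.

```python
import torch

# With .data: the in-place change is invisible to autograd.
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = a.sigmoid()
out.data.zero_()       # overwrites the output that sigmoid's backward reuses
out.sum().backward()   # no error...
print(a.grad)          # ...but the gradient is tensor([0., 0., 0.]), which is wrong

# With .detach(): autograd detects the in-place change and refuses.
b = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out2 = b.sigmoid()
out2.detach().zero_()
detach_raised = False
try:
    out2.sum().backward()
except RuntimeError as err:
    detach_raised = True
    print("RuntimeError:", err)
```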
