Model doesn't train after custom weight initialization

I have a model that trains fine with the default weight initialization, but when I try to initialize its weights to custom values (generated by custom_weight_fn), it stops training. Here is the code:

def init_weights(m):
  if type(m) == nn.Linear or type(m) == nn.Conv2d:
    m.weight.data=custom_weight_fn(m.weight.shape)
    m.bias.data=custom_weight_fn(m.bias.shape)

model = nn.Sequential(
    nn.Conv2d(3, channel_1, 5, padding=2),
    nn.ReLU(),
    nn.Conv2d(channel_1, channel_2, 3, padding=1),
    nn.ReLU(),
    Flatten(),
    nn.Linear(32*32*channel_2, 10)
)

model.apply(init_weights)

Does anybody know why this is not working? The code runs without any errors but the model doesn’t train.

Thanks!

Don’t use the .data attribute, as it’s deprecated and can cause unwanted side effects.
Instead, wrap the code in a with torch.no_grad() block and either assign new nn.Parameters to the module’s parameters or copy the new values into the existing parameters with copy_.
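Something along these lines should work for the copy_-based variant (custom_weight_fn is a placeholder for your own function):

def init_weights(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        with torch.no_grad():
            # copy the custom values into the existing parameters in-place
            m.weight.copy_(custom_weight_fn(m.weight.shape))
            m.bias.copy_(custom_weight_fn(m.bias.shape))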


Thanks for your suggestion. But when I apply these changes, it still doesn’t train… Here is my updated code:

def init_weights(m):
  with torch.no_grad():
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
      m.weight = nn.Parameter(custom_weight_fn(m.weight.shape))
      m.bias = nn.Parameter(custom_weight_fn(m.bias.shape))

model = nn.Sequential(
    nn.Conv2d(3, channel_1, 5, padding=2),
    nn.ReLU(),
    nn.Conv2d(channel_1, channel_2, 3, padding=1),
    nn.ReLU(),
    Flatten(),
    nn.Linear(32*32*channel_2, 10)
)

model.apply(init_weights)  # if I comment out this line, the model trains with no problem

I also tried m.param.copy_ and unfortunately that doesn’t work either.

In that case I guess your custom init method might yield a bad initialization, as using e.g. torch.randn works fine and creates valid gradients:

import torch
import torch.nn as nn


def custom_weight_fn(shape):
    return torch.randn(shape)


def init_weights(m):
  with torch.no_grad():
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
      m.weight = nn.Parameter(custom_weight_fn(m.weight.shape))
      m.bias = nn.Parameter(custom_weight_fn(m.bias.shape))

channel_1 = 3
channel_2 = 3
model = nn.Sequential(
    nn.Conv2d(3, channel_1, 5, padding=2),
    nn.ReLU(),
    nn.Conv2d(channel_1, channel_2, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32*32*channel_2, 10)
)

model.apply(init_weights) 

out = model(torch.randn(1, 3, 32, 32))
out.mean().backward()

for name, param in model.named_parameters():
    print('{}, grad.abs().max(): {}'.format(name, param.grad.abs().max()))

You are right! I had a small issue: when I initialize the weights with my custom init function, training becomes much harder, but after some hyper-parameter tuning the model does start to train. Your suggestion on how to set the weights to custom values works well.
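For anyone hitting the same problem later: the issue in my case was the scale of the custom values. As a rough sketch (the fan-in rescaling below is just an example, not what my actual custom_weight_fn does), rescaling the custom weights to a Kaiming-like standard deviation before copying them in made training behave much more like the default init:

import math

def init_weights(m):
  with torch.no_grad():
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
      w = custom_weight_fn(m.weight.shape)
      fan_in = m.weight[0].numel()  # number of inputs feeding each output unit
      # rescale to a Kaiming-like std so activations keep a reasonable scale
      w = w / w.std() * math.sqrt(2.0 / fan_in)
      m.weight.copy_(w)
      m.bias.zero_()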