Weight initialization with a custom method in nn.Sequential

Hi, I am trying to initialize the weights of a conv net (with nn.Sequential) using a custom method.

When I use this initialization, my network reaches only about 10% accuracy. CIFAR-10 has 10 classes, so this is chance level: the network learns nothing. Without this initialization I get about 58% accuracy, so the conv net itself can learn.

I am sure I am doing something wrong, but I don't know where the problem is. I would like to initialize the weights using the weights_init, random_weight and zero_weight functions. Any help or advice is welcome. Thanks.

The code (some of it comes from Stanford's CS231n course):

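```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Assumed setup: in_channel, channel_1, channel_2, num_classes and
# learning_rate are defined elsewhere (as in the CS231n notebook).
dtype = torch.float32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def random_weight(shape):
    # Kaiming normalization: sqrt(2 / fan_in)
    if len(shape) == 2:  # FC weight
        fan_in = shape[0]
    else:
        fan_in = np.prod(shape[1:])  # conv weight [out_channel, in_channel, kH, kW]

    w = torch.randn(shape, device=device, dtype=dtype) * np.sqrt(2. / fan_in)
    w.requires_grad = True
    return w

def zero_weight(shape):
    return torch.zeros(shape, device=device, dtype=dtype, requires_grad=True)

def weights_init(m):
    # Re-initialize every conv and linear layer; other modules are left alone
    if type(m) in [nn.Conv2d, nn.Linear]:
        m.weight.data = random_weight(m.weight.data.size())
        m.bias.data = zero_weight(m.bias.data.size())

class Flatten(nn.Module):
    def forward(self, x):
        # Keep the batch dimension and flatten everything else
        return x.view(x.shape[0], -1)

model = nn.Sequential(
    nn.Conv2d(in_channel, channel_1, (5, 5), padding=2),
    nn.ReLU(),
    nn.Conv2d(channel_1, channel_2, (3, 3), padding=1),
    nn.ReLU(),
    Flatten(),
    nn.Linear(channel_2 * 32 * 32, num_classes),
)
model.apply(weights_init)  # calls weights_init on every submodule

optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                      momentum=0.9, nesterov=True)
```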

I tested your code on some random input, and the model was able to fit it.
On CIFAR-10, with the conv layers set to 6 and 12 channels, the model does indeed struggle to learn, and the default random init seems to work better.
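nn.Linear stores its weight as [out_features, in_features], so the fan-in of a linear layer is shape[1], not shape[0]. Changing to fan_in = shape[1] for linear layers makes it a bit better.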
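To see how large the effect of the index is, here is a minimal sketch that builds only the last layer and compares the standard deviation random_weight would use with fan_in = shape[0] against the one from the true fan-in, shape[1]. channel_2 = 16 is a hypothetical value (the post does not give it); num_classes = 10 matches CIFAR-10.

```python
import math
import torch.nn as nn

# Hypothetical sizes: channel_2 = 16 is an assumption; 10 classes for CIFAR-10
channel_2, num_classes = 16, 10
fc = nn.Linear(channel_2 * 32 * 32, num_classes)

# nn.Linear stores its weight as [out_features, in_features]
out_features, in_features = fc.weight.shape

std_shape0 = math.sqrt(2.0 / out_features)  # what fan_in = shape[0] gives
std_shape1 = math.sqrt(2.0 / in_features)   # Kaiming std from the true fan-in

print(f"fan_in = shape[0]: std = {std_shape0:.4f}")
print(f"fan_in = shape[1]: std = {std_shape1:.4f}")
print(f"ratio: {std_shape0 / std_shape1:.1f}x")
```

With these assumed sizes, shape[0] yields a standard deviation about 40 times larger than the fan-in rule intends (sqrt(16384 / 10) ≈ 40.5), so the initial logits are much larger than Kaiming's analysis assumes.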
I don’t think you have a code bug, but probably your current custom initialization does not provide the benefit you expect.
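If the aim is simply Kaiming-normal weights with zero biases, PyTorch's built-in torch.nn.init helpers compute the fan-in from each weight's own layout, so the linear-layer case is handled for you. A sketch, assuming the 6 and 12 channel sizes mentioned above as illustrative values and using nn.Flatten (available in recent PyTorch versions) in place of the custom Flatten module. This only makes the scheme follow the standard fan-in convention; it is not guaranteed to beat the default init.

```python
import torch
import torch.nn as nn

def weights_init(m):
    # Kaiming normal, std = sqrt(2 / fan_in); PyTorch derives fan_in from the
    # weight layout ([out, in] for Linear, [out, in, kH, kW] for Conv2d)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        nn.init.zeros_(m.bias)

# Illustrative sizes: 3 input channels and 10 classes as in CIFAR-10, 6/12 conv channels
model = nn.Sequential(
    nn.Conv2d(3, 6, (5, 5), padding=2),
    nn.ReLU(),
    nn.Conv2d(6, 12, (3, 3), padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(12 * 32 * 32, 10),
)
model.apply(weights_init)

# Forward a random CIFAR-sized batch to check the shapes line up
x = torch.randn(4, 3, 32, 32)
logits = model(x)
print(logits.shape)
```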