Defining redundant layers in the constructor affects the result

Hi all,

To be more specific about this question, I tested it on the simple MNIST example (mnist). I also fixed the seed to ensure reproducibility. But if I add one additional layer in the module constructor and don't use that layer in the forward function, e.g. as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=5)  # extra layer, never used in forward
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

I got different results. Why did that happen? Is that because a different weight initialization was applied?
Thanks in advance!

Probably yes.
We had a similar discussion a while ago, and fixing the seed in such a setup to get the same initializations can be a bit tricky.
You could instead use the state_dict of one model to set all parameters of the other.
Since you are using dropout in your model, you would additionally have to set the seed before each forward pass.
Here is a small example using your model definition:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NetA(nn.Module):
    def __init__(self):
        super(NetA, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

class NetB(nn.Module):
    def __init__(self):
        super(NetB, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv3 = nn.Conv2d(10, 20, kernel_size=5)  # extra, unused layer
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

modelA = NetA()
modelB = NetB()
modelB.load_state_dict(modelA.state_dict(), strict=False)  # strict=False, since modelA has no parameters for conv3

# check some params
print((modelA.conv1.weight==modelB.conv1.weight).all())
print((modelA.conv2.weight==modelB.conv2.weight).all())
print((modelA.fc1.weight==modelB.fc1.weight).all())

x = torch.randn(1, 1, 28, 28)
torch.manual_seed(2809)  # reseed so dropout draws the same mask in both passes
outputA = modelA(x)
torch.manual_seed(2809)  # reseed again before the second forward pass
outputB = modelB(x)
print(outputA==outputB)
> tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.uint8)

Hi, I also tried to use the same weights for initialization:

def init_weights(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)

model.apply(init_weights)

But still, the results were different. Are there any other reasons?

I assume the problem is the order in which the sequence of generated pseudo-random numbers is consumed. E.g., given a hypothetical pseudo-random sequence [1, 5, 4, 6, 7], conv1, conv2, and the dataloader might pick up 1, 5, 4 respectively; but if you add another layer, conv3, then conv1, conv2, conv3, and the dataloader pick up 1, 5, 4, 6 respectively. The base seed of the dataloader will therefore differ between the two setups.
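
To illustrate this (a minimal sketch, not from the original thread): creating an extra layer consumes additional numbers from the default PRNG, so every random operation that follows sees a shifted state.

import torch
import torch.nn as nn

torch.manual_seed(0)
_ = nn.Conv2d(1, 10, kernel_size=5)   # conv1 init draws random numbers
_ = nn.Conv2d(10, 20, kernel_size=5)  # conv2 init draws more
print(torch.rand(3))                  # stands in for the next random op (dropout, shuffling, ...)

torch.manual_seed(0)
_ = nn.Conv2d(1, 10, kernel_size=5)
_ = nn.Conv2d(10, 20, kernel_size=5)
_ = nn.Conv2d(10, 20, kernel_size=5)  # the extra conv3 advances the PRNG further
print(torch.rand(3))                  # different values than the first print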


Yeah, as I said, seeding is a bit tricky in such a situation. That is why in my example I loaded the state_dict and also set the random seed before both forward passes to get the same result.
As @musicpiece explained, each additional call to the pseudo-random number generator advances its state, so the following random operations yield different values.
Since you are initializing an additional layer in your second model, the PRNG gets an additional call and the next random operation might behave differently (e.g. the Dropout layers might produce different masks in the two models). Also, if you call the models sequentially, as in my example, without reseeding, Dropout will most likely behave differently.
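
For example (a small sketch reusing modelA, modelB, and x from the snippet above, both models left in their default train() mode): without reseeding between the two forward passes, the dropout masks will most likely differ even though the weights are identical.

torch.manual_seed(2809)
outA = modelA(x)
outB = modelB(x)  # the PRNG has advanced, so dropout draws a different mask here
print((outA == outB).all())  # most likely prints a False/0 result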
