ResNet reproducibility

Hi everyone :slight_smile:

I have two models that are essentially the same (same architecture, same number of parameters), but they yield different results. The first is from the torchvision model collection (a ResNet18 without pretrained weights), and the other is essentially copy-pasted code, slightly reformatted (I later want to try some things with the ResNet architecture, which is why I had to code it myself).

Somehow they yield different results even though I seed my code to make it reproducible. Does anyone know why that is the case? Is it because, although they are conceptually the same, PyTorch initializes their weights differently because they are different instances?

Any help is very much appreciated!

All the best
snowe

Edit: I just realised that when I instantiate two ResNets from PyTorch, they also yield different results, even though they are the same. Is that behaviour expected? Is it because of some randomness in BatchNorm or something similar?

You should be able to get the same results between runs; at least I'm pretty sure of that. Which seeds did you set? I tried:

import torch
import torchvision.models as models

seed = 42
torch.manual_seed(seed)
resnet18 = models.resnet18(pretrained=False)  # newer torchvision versions: weights=None
x = torch.randn((1, 3, 224, 224))
print(resnet18(x)[0, :10])  # first 10 logits of the single sample

And I seem to get the same results across runs. If you are running on CUDA, you should also add:

torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

When training, I think it is also important to seed the DataLoader (its shuffling and its worker processes). Perhaps you could share a small code snippet that gives you different results?
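One way to seed the DataLoader, sketched here under the assumption of a toy `TensorDataset`, is to give it a dedicated `torch.Generator` for shuffling and a `worker_init_fn` that seeds NumPy and `random` inside each worker. The helper names `seed_worker`, `make_loader`, and `batch_order` are hypothetical; `generator` and `worker_init_fn` are standard `DataLoader` arguments.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Hypothetical helper: derive each worker's NumPy/`random` seed from the
    # per-worker seed PyTorch assigns (only called when num_workers > 0).
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_loader(dataset, seed, num_workers=0):
    # A dedicated generator controls the shuffling order independently
    # of the global PRNG (which model initialization also consumes).
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=seed_worker,
        generator=g,
    )


def batch_order(loader):
    # Collect the sample values in the order the loader yields them.
    return [idx.tolist() for (idx,) in loader]


data = TensorDataset(torch.arange(16))
order1 = batch_order(make_loader(data, seed=0))
order2 = batch_order(make_loader(data, seed=0))
print(order1 == order2)  # True: same seed, same shuffling order
```

Because the loader has its own generator, creating or re-initializing models between epochs does not change the data order.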

Hi @AladdinPerzon, thank you for your response!

I seed like this:

import random
import numpy as np
import torch

seed = 42  # example value
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# for cuda
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False

I am able to reproduce the results if I load the ResNet from PyTorch and use it over and over again. But when I use my own implementation of the ResNet with the same architecture and the same number of parameters, I don't get the same results as with the loaded one.

In other words, the loaded model and my own implementation yield different results, although the network is essentially the same…

Ah, got it. It seems to me that it must be the weight initialization, which is strange since we've set the seed. This is odd, and it is indeed what I also get:

Edit: I just realised that when I instantiate two ResNets from PyTorch they also yield different results, even though they are the same.

Maybe someone can clarify this :slight_smile:

I cannot reproduce this issue:

import torch
import torchvision.models as models

torch.manual_seed(2809)

modelA = models.resnet18()
modelB = models.resnet18()
modelB.load_state_dict(modelA.state_dict())

x = torch.randn(8, 3, 224, 224)

outA = modelA(x)
outB = modelB(x)
print((outA - outB).abs().max())
> tensor(0., grad_fn=<MaxBackward1>)

Could you clarify your use case or post a code snippet to produce this issue?


If I understand correctly, the topic author is trying to understand why their own implementation and the torchvision resnet18 give different results. First, I would try to catch obvious mistakes, like instantiating the pretrained version from torchvision. Second, I would go through the source code of the torchvision model and make sure that all layer instantiations and weight initializations in my own model match those in the 'reference' model. Without looking at the 'own' model code, it is too broad a topic to discuss, right?


That's right, and a good idea to solve the original issue. As a smoke test, I would first make sure the torchvision implementation yields the same results, as I guess there might be a misunderstanding about how seeding works, or another issue.


Could you explain why we are doing this line:

modelB.load_state_dict(modelA.state_dict())

If I do

import torch
import torchvision.models as models

seed = 42
torch.manual_seed(seed)
resnetA = models.resnet18(pretrained=False)
resnetB = models.resnet18(pretrained=False)
x = torch.randn((1, 3, 224, 224))
print((resnetA(x) - resnetB(x)).abs().max())

I obtain different results. I'm assuming this is expected, but it is not clear to me why they are different.

Your code should yield different results, since you are only seeding once.
Each call into the pseudorandom number generator (PRNG) yields new random numbers.
Seeding makes the sequence of these pseudorandom numbers reproducible, but successive calls won't return the same numbers:

import torch

torch.manual_seed(2809)
print(torch.randn(2))
> tensor([-2.0748,  0.8152])
print(torch.randn(2))
> tensor([-1.1281,  0.8386])

torch.manual_seed(2809)
print(torch.randn(2))
> tensor([-2.0748,  0.8152])
print(torch.randn(2))
> tensor([-1.1281,  0.8386])
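Re-seeding is one way to replay the sequence. A sketch of an equivalent approach is to snapshot the CPU generator state with `torch.get_rng_state()` and rewind to it with `torch.set_rng_state()`; the variable names are illustrative only.

```python
import torch

torch.manual_seed(2809)
state = torch.get_rng_state()            # snapshot of the CPU generator state
a1, a2 = torch.randn(2), torch.randn(2)  # two successive calls: different values

torch.set_rng_state(state)               # rewind to the snapshot
b1, b2 = torch.randn(2), torch.randn(2)  # replays exactly the same sequence

print(torch.equal(a1, b1), torch.equal(a2, b2), torch.equal(a1, a2))  # True True False
```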

If you want to use the seed to initialize both models with the same random values, you would have to re-seed the code:


import torch
import torchvision.models as models

torch.manual_seed(2809)
modelA = models.resnet18()
torch.manual_seed(2809)
modelB = models.resnet18()


x = torch.randn(8, 3, 224, 224)

outA = modelA(x)
outB = modelB(x)
print((outA - outB).abs().max())
> tensor(0., grad_fn=<MaxBackward1>)
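To confirm that re-seeding really produced identical weights, without relying on a single forward pass, the two `state_dict`s (parameters plus buffers such as BatchNorm running stats) can be compared directly. A minimal sketch; `same_state` is a hypothetical helper and `build` is a small stand-in for `models.resnet18()` so the example runs without torchvision.

```python
import torch
import torch.nn as nn


def same_state(model_a: nn.Module, model_b: nn.Module) -> bool:
    # Compare every parameter and buffer by name.
    sa, sb = model_a.state_dict(), model_b.state_dict()
    if sa.keys() != sb.keys():
        return False
    return all(torch.equal(sa[k], sb[k]) for k in sa)


def build():
    # Hypothetical stand-in model: conv + BN + linear head.
    return nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
    )


torch.manual_seed(2809)  # seeded once: the second model gets later draws
a, b = build(), build()

torch.manual_seed(2809)  # re-seeded before each model: identical weights
c = build()
torch.manual_seed(2809)
d = build()

print(same_state(a, b), same_state(c, d))  # False True
```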

Great explanation! Thanks ))


That makes sense, thank you :pray:


Thank you all for your input! :slight_smile:

So I can reproduce the results of two ResNets I load from PyTorch. However, my own implementation still yields different results from the loaded ones; I've checked it multiple times and cannot figure out why it behaves differently. I can reproduce the results of my own implementation as well, so there doesn't seem to be any weird randomness happening…

Here is my ResNet implementation:

import torch
import torch.nn as nn

# this type of block is used to build ResNet18 and ResNet34
class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.relu = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.downsample = downsample
        self.stride = stride
        
    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        if self.downsample is not None:
            identity = self.downsample(identity)
            
        out += identity
        out = self.relu(out)
        
        return out
    
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes):
        super(ResNet, self).__init__()
        
        self.in_channels = 64 
        
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=self.in_channels, kernel_size=7, 
                               padding=3, stride=2, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
        
        self.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)
        self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)
        self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)
        self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        
    def _make_layer(self, block, num_blocks, out_channels, stride):
        downsample = None
        
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels*block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels*block.expansion)
            )
        
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(self.in_channels, out_channels))
            
        return nn.Sequential(*layers)
    
    def forward(self, x):
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        return x

And here is how I test it:

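import random

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as PyTorchModels  # alias inferred from the calls below

def set_seed(seed):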
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed) 
    # for cuda
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False

set_seed(0)
modelA = PyTorchModels.resnet18()
in_ = modelA.fc.in_features
classes = 10
modelA.fc = nn.Linear(in_features=in_, out_features=classes)

set_seed(0)
modelB = PyTorchModels.resnet18()
in_ = modelB.fc.in_features
classes = 10
modelB.fc = nn.Linear(in_features=in_, out_features=classes)

set_seed(0)
modelC = ResNet(BasicBlock, [2, 2, 2, 2], 10)

t = torch.rand(32, 3, 32, 32)

outA = modelA(t)
outB = modelB(t)
outC = modelC(t)

print(outA[0])
print('\n')
print(outB[0])
print('\n')
print(outC[0])

tensor([-0.0160, -0.0413,  0.5379, -0.3654, -0.0620, -0.7079, -0.9632, -0.9346,
         1.5941,  1.0369], grad_fn=<SelectBackward>)

tensor([-0.0160, -0.0413,  0.5379, -0.3654, -0.0620, -0.7079, -0.9632, -0.9346,
         1.5941,  1.0369], grad_fn=<SelectBackward>)

tensor([ 0.1272,  0.1153, -0.4902, -0.2696, -0.4524, -0.4243, -0.5799, -0.0227,
         0.5023,  0.8597], grad_fn=<SelectBackward>)

So the two loaded ResNets behave the same, but differently from my own…

I get the same results if I make sure to use the same sequence of calls into the PRNG:

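import torch
import torch.nn as nn
import torchvision.models as models

# ResNet and BasicBlock as defined in the post above
torch.manual_seed(2809)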
modelA = ResNet(BasicBlock, [2, 2, 2, 2], 1000)
in_ = modelA.fc.in_features
classes = 10
modelA.fc = nn.Linear(in_features=in_, out_features=classes)
torch.manual_seed(2809)
modelB = ResNet(BasicBlock, [2, 2, 2, 2], 1000)
in_ = modelB.fc.in_features
classes = 10
modelB.fc = nn.Linear(in_features=in_, out_features=classes)

torch.manual_seed(2809)
modelC = models.resnet18()
in_ = modelC.fc.in_features
modelC.fc = nn.Linear(in_features=in_, out_features=classes)


x = torch.randn(8, 3, 224, 224)

outA = modelA(x)
outB = modelB(x)
outC = modelC(x)
print((outA - outB).abs().max())
> tensor(0., grad_fn=<MaxBackward1>)

print((outA - outC).abs().max())
> tensor(0., grad_fn=<MaxBackward1>)

In your example you are using modelC = ResNet(BasicBlock, [2, 2, 2, 2], 10), which will directly create a linear layer with 10 output classes.
While it’s a valid approach for your use case, this will break the comparison using seeds, since the calls to the PRNG in the torchvision implementation are:

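-> init conv1, bn1
-> init layer1
-> init layer2
...
-> init fc with 1000 output classes
-> re-init all conv weights with kaiming_normal_ (the loop at the end of __init__)
-> init custom nn.Linear with 10 output classes

while your model draws fewer random numbers when it creates its 10-class fc, which shifts every draw after it (including the conv re-initialization), and it never creates the extra nn.Linear.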
If you use my code, you should get the same results.
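The effect of the head size on later draws can be reproduced in isolation. This sketch mimics the order of calls (create a linear head, then create and re-initialize a conv layer from the same global PRNG); `conv_after_head` is a hypothetical helper, and the layer sizes are taken from ResNet18's stem and head.

```python
import torch
import torch.nn as nn


def conv_after_head(num_classes: int) -> torch.Tensor:
    torch.manual_seed(2809)
    # The linear head's initialization consumes a number of random draws
    # that depends on its size (512 * num_classes weights + biases).
    nn.Linear(512, num_classes)
    # Like torchvision's ResNet: construct the conv, then re-init it with kaiming_normal_.
    conv = nn.Conv2d(3, 64, kernel_size=7, bias=False)
    nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
    return conv.weight.detach().clone()


w_1000 = conv_after_head(1000)
w_10 = conv_after_head(10)
print(torch.equal(w_1000, w_10))               # False: the PRNG stream was shifted
print(torch.equal(w_10, conv_after_head(10)))  # True: same call sequence, same weights
```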

That being said, I would recommend not using the seeding approach to compare models, as you would need to know exactly in which order the layers are initialized.
The better approach is just to load the state_dict from one model into the other and test both models.
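A minimal sketch of that approach: copy the reference `state_dict` into the custom model with `strict=True` (which raises a `RuntimeError` if any parameter name or shape differs, already flagging an architectural mismatch), then compare outputs in eval mode. `outputs_match` and `build` are hypothetical names; `build` is a tiny stand-in so the example runs without torchvision.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def outputs_match(reference: nn.Module, custom: nn.Module, x: torch.Tensor, atol: float = 1e-6) -> bool:
    # strict=True: missing, unexpected, or mis-shaped keys raise a RuntimeError.
    custom.load_state_dict(reference.state_dict(), strict=True)
    # eval(): BatchNorm uses (and does not update) its running statistics.
    reference.eval()
    custom.eval()
    return torch.allclose(reference(x), custom(x), atol=atol)


def build():
    # Hypothetical stand-in for a reference/custom pair with identical layer names.
    return nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
    )


torch.manual_seed(0)
ref = build()
torch.manual_seed(1)  # different init on purpose: load_state_dict overrides it
cus = build()
x = torch.randn(2, 3, 8, 8)
print(outputs_match(ref, cus, x))  # True
```

With the models from this thread, the call would look something like `outputs_match(models.resnet18(), ResNet(BasicBlock, [2, 2, 2, 2], 1000), torch.randn(2, 3, 224, 224))`. This relies on the custom ResNet using the same parameter names as torchvision (`conv1`, `bn1`, `layer1.0.conv1`, `layer2.0.downsample.0`, `fc`, ...), which the implementation above appears to do.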


Thank you so much @ptrblck, it is working now! :slight_smile:

Usually I wouldn't use seeding to compare models, but this time I just wanted to make sure that my implementation was correct, so I figured I'd check the results using the same seed. But I see how this can lead to issues, as in this case.

Anyways, once again, thank you for the support!

All the best
snowe