PyTorch ResNet BatchNorm problem

Hi, I’m using PyTorch’s pretrained ResNet-50 model, and I came across a problem with its output.

Assume I train for 10 epochs and the best validation result is at epoch 5, so I save epoch 5’s model. When I re-run the saved epoch 5 model on the validation set, its predictions differ from the predictions made during training (of course, also with the epoch 5 weights).

After some investigation, the problem seems to come from a layer (backbone.layer1) that contains BatchNorm (BN) layers. The input to this layer is the same in both runs, but small errors appear in its output and then propagate through the later layers all the way to the final output.

During both evaluation phases (one during training and one after training) I call model.eval(), so I’m confused why there is still a difference.
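In eval mode a BatchNorm layer normalizes with its stored running_mean and running_var instead of the current batch statistics, so it should act as a fixed per-channel affine transform. A minimal sketch, assuming plain nn.BatchNorm1d/2d/3d layers, that lists any BN module still in training mode and checks an eval-mode BatchNorm2d against that formula computed by hand. The helper `bn_layers_in_train_mode` is hypothetical, not a PyTorch API.

```python
import torch
import torch.nn as nn

_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def bn_layers_in_train_mode(model: nn.Module):
    """Hypothetical helper: names of BatchNorm modules whose .training flag is still True."""
    return [name for name, m in model.named_modules()
            if isinstance(m, _BN_TYPES) and m.training]


torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
bn.train()
bn(torch.randn(8, 4, 5, 5))  # one training-mode pass updates running_mean / running_var

bn.eval()
x = torch.randn(2, 4, 5, 5)
with torch.no_grad():
    out = bn(x)
    # Eval-mode BN: (x - running_mean) / sqrt(running_var + eps) * weight + bias
    shape = (1, -1, 1, 1)
    manual = ((x - bn.running_mean.view(shape))
              / torch.sqrt(bn.running_var.view(shape) + bn.eps)
              * bn.weight.view(shape) + bn.bias.view(shape))

print(bn_layers_in_train_mode(bn))             # []
print(torch.allclose(out, manual, atol=1e-6))  # True
```

Running the helper on the full model right before each evaluation is a cheap way to confirm that no submodule was left in training mode.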

I just tried to reproduce this issue and it seems to work fine:

import torch
import torch.nn as nn
from torchvision import models

model = models.resnet50(pretrained=True)

data = torch.randn(10, 3, 224, 224)
target = torch.randint(0, 1000, (10,))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Train for 5 epochs
for epoch in range(5):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}'.format(epoch, loss.item()))

# Predict on validation data
val_data = torch.randn(10, 3, 224, 224)
model.eval()
pred1 = model(val_data)

# Save and restore model
torch.save(model.state_dict(), 'tmp.pt')

model = models.resnet50(pretrained=False)
model.load_state_dict(torch.load('tmp.pt'))

# Predict on same validation data
model.eval()
pred2 = model(val_data)

# Compare
print((pred1 == pred2).all())
> tensor(1, dtype=torch.uint8)

Could you compare your code with mine and check for differences, or let me know if I’m missing something?
How large is the difference? Note that you might encounter some floating point precision issues.
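To quantify "how large", a sketch that reports the maximum and mean absolute deviation and the maximum relative deviation between two output tensors, instead of testing exact equality. `pred1`/`pred2` here are random stand-ins for the two runs' outputs, and the 1e-6 perturbation is an arbitrary assumption for illustration.

```python
import torch


def report_difference(a: torch.Tensor, b: torch.Tensor) -> dict:
    """Summarize how far two output tensors are apart."""
    diff = (a - b).abs()
    rel = diff / b.abs().clamp_min(1e-12)  # guard against division by zero
    return {
        "max_abs": diff.max().item(),
        "mean_abs": diff.mean().item(),
        "max_rel": rel.max().item(),
        "exact_equal": torch.equal(a, b),
    }


torch.manual_seed(0)
pred1 = torch.randn(10, 1000)
pred2 = pred1 + 1e-6 * torch.randn(10, 1000)  # stand-in for a slightly different second run
print(report_difference(pred1, pred2))
```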

Thanks for the reply. At the beginning the difference is small, around 1e-8, but after the later, more complex parts of the network it grows to around 1e-2… Is there any way to avoid this kind of error propagation?
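If the drift stems from float32 rounding, exact equality is too strict a test. One common workaround is to compare logits within a tolerance using `torch.allclose` and to check whether the predicted classes (argmax) still agree. The helper name `same_predictions`, the simulated drift, and the tolerances are illustrative assumptions, not recommended values.

```python
import torch


def same_predictions(logits1: torch.Tensor, logits2: torch.Tensor,
                     rtol: float = 1e-3, atol: float = 1e-3):
    """Return (close_within_tolerance, fraction_of_samples_with_same_argmax)."""
    close = torch.allclose(logits1, logits2, rtol=rtol, atol=atol)
    agree = (logits1.argmax(dim=1) == logits2.argmax(dim=1)).float().mean().item()
    return close, agree


torch.manual_seed(0)
logits1 = torch.randn(10, 1000)
logits2 = logits1 + 1e-5 * torch.randn(10, 1000)  # small simulated drift
print(same_predictions(logits1, logits2))
```

Comparing argmax agreement reflects what matters for classification accuracy, while the tolerance check flags drifts that are larger than rounding noise.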

Here is a short code sample to make it clearer:

import torch.nn as nn
from torchvision.models import resnet


def _load_resnet_imagenet(pretrained=True):
    # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py
    backbone = resnet.resnet50(pretrained=pretrained)
    # move the stride-2 downsampling from conv2 (3x3) to conv1 (1x1) in layer2 and layer3
    for i in range(2, 4):
        getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2)
        getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1)
    # use stride 1 for the last conv4 layer (same as tf-faster-rcnn)
    backbone.layer4[0].conv2.stride = (1, 1)
    backbone.layer4[0].downsample[0].stride = (1, 1)
    return backbone


class Backbone(nn.Module):  # hypothetical wrapper: the original excerpt sat inside a module's __init__
    def __init__(self, pretrained=True):
        super().__init__()
        backbone = _load_resnet_imagenet(pretrained=pretrained)
        self.backbone = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            # backbone.layer4
        )

    def forward(self, x):
        return self.backbone(x)

I found that the output difference starts at backbone.layer1: this layer’s input is the same, but its output is slightly different. I also checked the model’s training flag during evaluation, and it is False in both cases.
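Locating the first diverging stage, as done here for backbone.layer1, can be automated: run the same input through two `nn.Sequential` models child by child and record the maximum absolute difference after each stage. A sketch, assuming both models are `nn.Sequential` with matching children (for example the in-training model and the reloaded checkpoint). `make_net` is a hypothetical stand-in for the ResNet backbone so the example runs without downloading weights.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def per_stage_difference(seq_a: nn.Sequential, seq_b: nn.Sequential, x: torch.Tensor):
    """Max abs difference after each child module of two Sequential models."""
    diffs = []
    out_a, out_b = x, x
    for (name, mod_a), mod_b in zip(seq_a.named_children(), seq_b.children()):
        out_a, out_b = mod_a(out_a), mod_b(out_b)
        diffs.append((name, (out_a - out_b).abs().max().item()))
    return diffs


def make_net():  # hypothetical stand-in for the ResNet backbone
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
                         nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8))


torch.manual_seed(0)
net_a = make_net().eval()
net_b = make_net().eval()
net_b.load_state_dict(net_a.state_dict())  # same weights and BN running stats
x = torch.randn(2, 3, 16, 16)
diffs = per_stage_difference(net_a, net_b, x)
for name, d in diffs:
    print(name, d)
```

With the real backbone, the first stage whose difference jumps above rounding level is the one to inspect more closely.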

Thanks for the code!

If you are using the GPU, could you try setting the cuDNN flags described in the Reproducibility docs and see if the error decreases?
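A minimal sketch of applying those settings before building the model. `torch.backends.cudnn.deterministic`, `torch.backends.cudnn.benchmark` and `torch.use_deterministic_algorithms` are existing PyTorch settings (the last one requires a reasonably recent release); the `make_deterministic` helper name and the seed value are arbitrary choices. These flags make GPU runs repeatable; they do not remove float32 rounding error itself.

```python
import os
import random

import torch

# Required by deterministic cuBLAS on CUDA >= 10.2; must be set before cuBLAS is first used.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")


def make_deterministic(seed: int = 0) -> None:
    random.seed(seed)
    torch.manual_seed(seed)                     # seeds the CPU and all CUDA devices
    torch.backends.cudnn.benchmark = False      # don't auto-tune (possibly varying) conv algorithms
    torch.backends.cudnn.deterministic = True   # use deterministic cuDNN conv algorithms
    torch.use_deterministic_algorithms(True)    # raise on ops without a deterministic implementation


make_deterministic(0)
```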
The initial error suggests floating point precision issues.
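A tiny illustration of why identical weights can still give different outputs: floating-point addition is not associative, so summing the same numbers in a different order (as different conv/GEMM kernels or thread splits may do) can round differently. Plain Python floats (float64) already show the effect; float32, which the model uses by default, has far fewer mantissa bits.

```python
a, b, c = 0.1, 0.2, 0.3

left = (a + b) + c   # 0.30000000000000004 + 0.3
right = a + (b + c)  # 0.1 + 0.5

print(left)           # 0.6000000000000001
print(right)          # 0.6
print(left == right)  # False
```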
Also, just for debugging purposes, could you cast all parameters and the input data to .double() and check the error again?
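A sketch of that debugging step: run the same small conv + BatchNorm stack once in float32 and once after `.double()` (which casts floating-point parameters and buffers, including BN running statistics), using the float64 result as the reference. The toy network is a hypothetical stand-in for the backbone.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
net32 = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
).eval()
net64 = copy.deepcopy(net32).double()  # same weights, params and buffers cast to float64

x = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    out32 = net32(x)
    out64 = net64(x.double())  # the input must be cast too, or the conv raises a dtype error

# Treat float64 as the reference; the gap is float32 rounding error.
err = (out32.double() - out64).abs().max().item()
print(f"max abs float32 vs float64 difference: {err:.2e}")
```

If two float64 runs of the real model agree much more closely than two float32 runs, that points to rounding rather than a bug in the BN handling.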

Hi, changing from float to double seems to work, but I’m still confused about how this can happen…
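One way to see why double helps: float32 has a machine epsilon of 2^-23 (about 1.2e-7, roughly 7 significant decimal digits), while float64 has 2^-52 (about 2.2e-16). Every float32 operation may round at that relative scale, and a deep network performs many such operations, so small discrepancies can accumulate. How much they grow in this particular network is an assumption, not something measured in this thread. `torch.finfo` reports these limits, and a 1e-8 update to 1.0 is lost entirely in float32:

```python
import torch

results = {}
for dtype in (torch.float32, torch.float64):
    one = torch.tensor(1.0, dtype=dtype)
    eps = torch.finfo(dtype).eps      # gap between 1.0 and the next representable value
    lost = bool(one + 1e-8 == one)    # is a 1e-8 update rounded away?
    results[dtype] = lost
    print(f"{dtype}: eps = {eps:.3e}, 1 + 1e-8 == 1 -> {lost}")
```

This prints True for float32 and False for float64.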