ResNet with CIFAR10 only reaches 86% accuracy (expecting >90%)

Hello everyone,

I am trying to reproduce the numbers from the original ResNet publication on CIFAR10. I am using the network implementation from here:

As far as I can tell, I am using the exact training parameters that are given in the paper:

We use a weight decay of 0.0001 and momentum of 0.9, and adopt the weight initialization in [13] and BN [16] but with no dropout. These models are trained with a mini-batch size of 128 on two GPUs. We start with a learning rate of 0.1, divide it by 10 at 32k and 48k iterations, and terminate training at 64k iterations, which is determined on a 45k/5k train/val split. We follow the simple data augmentation in [24] for training: 4 pixels are padded on each side, and a 32×32 crop is randomly sampled from the padded image or its horizontal flip.
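
Since Lightning's MultiStepLR milestones are given in epochs while the paper counts iterations, I converted the numbers with a quick back-of-the-envelope calculation (assuming the 45k train split and a mini-batch size of 128):

# rough conversion of the paper's iteration counts into epochs,
# assuming 45000 training images and a mini-batch size of 128
iters_per_epoch = 45000 // 128  # ~351 iterations per epoch
for iters in (32000, 48000, 64000):
    print(f"{iters} iterations ~ {iters / iters_per_epoch:.0f} epochs")
# -> 91, 137 and 182 epochs, hence milestones [90, 135] and 180 epochs of training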

Here are the relevant parts of my training script:

import torch
import torchvision
import mlflow.pytorch
import pytorch_lightning as pl
from torch import optim
from torchvision import transforms

import resnet  # the CIFAR10 ResNet implementation from the repository linked above

class ResNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = resnet.resnet20()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        # SGD with lr 0.1, momentum 0.9 and weight decay 1e-4, as in the paper
        optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        # learning rate drops at the epochs corresponding to 32k and 48k iterations
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[90, 135])
        return {'optimizer': optimizer, 'scheduler': scheduler}
...

class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, **kwargs):
        ...
        self.normalize = [
            transforms.ToTensor(),
            # per-channel CIFAR10 mean and standard deviation
            transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]),
        ]
        # augmentation from the paper: pad 4 pixels, random 32x32 crop, random horizontal flip
        self.augment = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]

    def prepare_data(self):
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            # note: the same transform (incl. augmentation) ends up on both splits
            t = transforms.Compose(self.augment + self.normalize)
            cifar_full = torchvision.datasets.CIFAR10(root=self.data_dir, train=True, transform=t)
            # 45k/5k train/val split, as in the paper
            self.df_train, self.df_val = torch.utils.data.random_split(cifar_full, [45000, 5000])
        if stage == 'test' or stage is None:
            t = transforms.Compose(self.normalize)
            self.df_test = torchvision.datasets.CIFAR10(root=self.data_dir, train=False, transform=t)

dm = CIFAR10DataModule(batch_size=128, num_workers=4)
dm.setup("fit")
lr_logger = pl.callbacks.LearningRateMonitor()
mlflow.pytorch.autolog()
model = ResNet()
trainer = pl.Trainer(gpus=1, max_epochs=180, callbacks=[lr_logger])
trainer.fit(model, dm) 

However, the accuracy only reaches around 86%, well below the 91.25% reported in the original paper.

There is a comment in the repository hosting the ResNet/CIFAR10 implementation which indicates that this issue first appeared after an update of PyTorch from version 1.1 to 1.2:

(note that the reported numbers in the issue refer to ResNet56, but the effect is the same, just less pronounced)

I suspect that I am seeing the same issue and would like to understand what is causing it and how I can best fix it.


Note that this is validation accuracy, not test accuracy. Now that I think about it, I wonder whether the drop in accuracy I am seeing is a side effect of the augmentation. I will report a value on the test set tomorrow!

I have evaluated against the test set and the effect stays the same: Training accuracy reaches around 93%, but test accuracy stagnates at around 85%.

I have found the issue, and it is a very subtle one:

When returning a scheduler to Lightning using the dict format, as I do in this line:

        return {'optimizer': optimizer, 'scheduler': scheduler}

the key for the scheduler needs to be lr_scheduler; otherwise the scheduler is silently ignored and the learning rate stays at its initial value for the entire training run.
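
For reference, here is the corrected method:

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[90, 135])
        # 'lr_scheduler' is the key Lightning looks for; 'scheduler' was silently ignored
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}

As far as I can tell, more recent Lightning releases validate the keys in this dictionary and would have caught the mistake; the version I was using accepted the unknown key silently.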


I want to use ImageFolder to load CIFAR10, but the accuracy only reaches 85% (expecting >96%).