Hello everyone,
I am trying to reproduce the numbers from the original ResNet publication on CIFAR10. I am using the network implementation from here:
As far as I can tell, I am using the exact training parameters that are given in the paper:
We use a weight decay of 0.0001 and momentum of 0.9, and adopt the weight initialization in [13] and BN [16] but with no dropout. These models are trained with a mini-batch size of 128 on two GPUs. We start with a learning rate of 0.1, divide it by 10 at 32k and 48k iterations, and terminate training at 64k iterations, which is determined on a 45k/5k train/val split. We follow the simple data augmentation in [24] for training: 4 pixels are padded on each side, and a 32×32 crop is randomly sampled from the padded image or its horizontal flip.
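For reference, here is how I convert the paper's iteration counts into epoch numbers for my setup (this arithmetic is my own, based on the 45k-image training split and mini-batch size of 128 from the quote above):

# Back-of-the-envelope conversion of the paper's iteration milestones to epochs,
# assuming 45,000 training images and a mini-batch size of 128.
iters_per_epoch = 45000 / 128                      # ~352 iterations per epoch
for iters in (32000, 48000, 64000):
    print(f"{iters} iterations ~ {iters / iters_per_epoch:.0f} epochs")
# 32000 iterations ~ 91 epochs
# 48000 iterations ~ 137 epochs
# 64000 iterations ~ 182 epochs

This is why I use learning rate milestones of [90, 135] and train for 180 epochs in the script below.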
Here are the relevant parts of my training script:
class ResNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = resnet.resnet20()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        # LR drops at epochs 90 and 135, intended to match the paper's 32k/48k iteration boundaries
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[90, 135])
        return {'optimizer': optimizer, 'scheduler': scheduler}
...
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, **kwargs):
        ...
        self.normalize = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]),
        ]
        self.augment = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]

    def prepare_data(self):
        torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            t = transforms.Compose(self.augment + self.normalize)
            cifar_full = torchvision.datasets.CIFAR10(root=self.data_dir, train=True, transform=t)
            # 45k/5k train/val split, as in the paper
            self.df_train, self.df_val = torch.utils.data.random_split(cifar_full, [45000, 5000])
        if stage == 'test' or stage is None:
            t = transforms.Compose(self.normalize)
            self.df_test = torchvision.datasets.CIFAR10(root=self.data_dir, train=False, transform=t)
dm = CIFAR10DataModule(batch_size=128, num_workers=4)
dm.setup("fit")

lr_logger = pl.callbacks.LearningRateMonitor()
mlflow.pytorch.autolog()

model = ResNet()
trainer = pl.Trainer(gpus=1, max_epochs=180, callbacks=[lr_logger])
trainer.fit(model, dm)
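To make sure the schedule is applied the way I think it is, I am also planning to add a small callback along these lines (just a sketch of my own, not part of the run above) that prints the effective learning rate at the start of every epoch:

# Sketch: print the optimizer's learning rate at the start of each epoch so the
# two scheduled drops are easy to spot in the console output.
class LRPrinter(pl.Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        lr = trainer.optimizers[0].param_groups[0]["lr"]
        print(f"epoch {trainer.current_epoch}: lr = {lr}")

It would simply be appended to the Trainer's callback list, e.g. callbacks=[lr_logger, LRPrinter()].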
However, the test accuracy only reaches around 86%, well below the 91.25% reported in the original paper.
There is a comment in the repository hosting the ResNet/CIFAR10 implementation which indicates that this issue seems to have appeared after updating PyTorch from version 1.1 to 1.2:
(note that the numbers reported in that issue refer to ResNet56, but the effect is the same, just less pronounced)
I suspect that I am seeing the same issue and would like to understand what is causing it and how I can best fix it.