Randomize TenCrop Data Augmentation

Hi All,

I am training ResNet on the CIFAR-10 dataset and using TenCrop data augmentation as the ResNet paper suggests. However, I found that after TenCrop (please see my dataset implementation below), the 10 crops taken from the same picture are glued together in a batch, even when I set shuffle=True. Is there any way to treat the 10 crops as independent training images, as we normally do (crop 1 in batch i, crop 2 in batch j, ..., instead of crops 1-10 all in the same batch)? Or does this not matter for training? I suspect it does matter, since glueing all the crops together decreases the randomness/variance of the gradients, which, according to the second answer of this thread, makes overfitting more likely, but I am not sure.

My implementation:

class CIFAR10(Dataset):
    def __init__(self, data_path, dataset, data_aug=False):
        """
        data_path: folder storing the train, valid and test sets
        dataset: "train", "valid" or "test"
        data_aug: whether to use data augmentation, default: False
        """
        # initialize object variables
        self.data_aug = data_aug

        # read in the dataset
        with open(f"{data_path}/{dataset}_set.pkl", "rb") as fin:
            self.cifar_imgs = pickle.load(fin)

        # define transforms
        # we keep the original image size 32*32, the same as the experiments in the ResNet paper
        # note that after ToTensor each pixel value lies in [0, 1],
        # so we can then apply Normalize
        if data_aug:
            # the data augmentation follows https://arxiv.org/pdf/1409.5185.pdf
            # see also the PyTorch docs on handling the extra dimension
            # introduced by the tuple that TenCrop() returns
            self.transformations = transforms.Compose([
                transforms.TenCrop(32, vertical_flip=False),  # returns a tuple of 10 PIL images
                transforms.Lambda(lambda crops: torch.stack(
                    [transforms.ToTensor()(crop) for crop in crops])),  # stack the tuple into [ncrops, C, H, W]
                transforms.Lambda(lambda tensors: torch.stack(
                    [transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))(t)
                     for t in tensors])),
            ])
        else:  # valid and test should not use data augmentation
            self.transformations = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
            ])
    def __getitem__(self, index):
        img = Image.fromarray(self.cifar_imgs[index][0]) # convert to PIL image
        label = self.cifar_imgs[index][1]
        # transform the image
        img = self.transformations(img)

        return img, label

    def __len__(self):
        # note: we must return the number of original images, not the number of
        # samples after augmentation; otherwise the sampler would draw indices up
        # to 10 * len(self.cifar_imgs), causing an out-of-range error in __getitem__
        return len(self.cifar_imgs)
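One possible way to decouple the crops (a sketch of my own, not something from the PyTorch docs) is to report a length of 10 * N and map each index to an (image, crop) pair. The DataLoader's shuffling then mixes crops from different images freely. The class name `TenCropAsSamples` and the `crop_fn` parameter are hypothetical; `crop_fn` would be something like `lambda img: TF.ten_crop(img, 32)` from `torchvision.transforms.functional`:

```python
from torch.utils.data import Dataset

class TenCropAsSamples(Dataset):
    """Hypothetical wrapper: expose each of the 10 crops as an independent
    sample, so DataLoader(shuffle=True) can scatter crops of one image
    across different batches."""

    def __init__(self, base_images, labels, crop_fn):
        # base_images: list of PIL images; labels: list of ints
        # crop_fn: maps one image to a tuple of 10 crops,
        #          e.g. lambda img: TF.ten_crop(img, 32)
        self.base_images = base_images
        self.labels = labels
        self.crop_fn = crop_fn

    def __len__(self):
        # 10 virtual samples per original image
        return 10 * len(self.base_images)

    def __getitem__(self, index):
        # map the flat index back to (which image, which crop)
        img_idx, crop_idx = divmod(index, 10)
        crops = self.crop_fn(self.base_images[img_idx])  # tuple of 10
        return crops[crop_idx], self.labels[img_idx]
```

One trade-off: `crop_fn` recomputes all 10 crops on every access, so you may want to cache the crops per image if cropping is expensive.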

Also, in the training loop I used the trick from the PyTorch docs to deal with the dimension inconsistency (4D vs. 5D) as follows:

if if_aug:
    bs, ncrops, c, h, w = img.size()
    img = img.view(-1, c, h, w)
    label = torch.repeat_interleave(label, 10)
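For reference, here is a minimal, self-contained check of that reshape trick with made-up batch sizes, showing that the crop dimension gets folded into the batch dimension and the labels are repeated to match:

```python
import torch

bs, ncrops, c, h, w = 4, 10, 3, 32, 32
img = torch.randn(bs, ncrops, c, h, w)  # [B, ncrops, C, H, W], as returned after TenCrop
label = torch.arange(bs)                # one label per original image

img = img.view(-1, c, h, w)                     # fold crops into the batch dim -> [B*ncrops, C, H, W]
label = torch.repeat_interleave(label, ncrops)  # repeat each label 10x to match

print(img.shape)   # torch.Size([40, 3, 32, 32])
print(label[:12])  # tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
```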