Randomize TenCrop Data Augmentation

Hi All,

I am training ResNet on the CIFAR-10 dataset and using TenCrop data augmentation, as the ResNet paper suggests. However, I found that after TenCrop (please see my dataset implementation below), the 10 crops taken from the same picture end up glued together in the same batch, even when I set shuffle=True. Is there any way around this, i.e. to treat the 10 crops as independent training images (crop 1 in batch i, crop 2 in batch j, ..., instead of crops 1-10 all in one batch)? Or does this not matter for training? I suspect it does matter, since gluing all the crops together reduces the randomness/variance of the gradients, which makes overfitting more likely according to the second answer of this thread, but I am not sure. One (untested) idea is sketched below, after my implementation.

My implementation:

import pickle

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CIFAR10(Dataset):
    def __init__(self, data_path, dataset, data_aug=False):
        """
        data_path: folder storing train, valid and test set
        dataset: "train", "valid" or "test"
        data_aug: whether to use data augmentation (default: False)
        """
        # initialize object variables
        self.data_aug = data_aug

        # read in the dataset
        with open(f"{data_path}/{dataset}_set.pkl", "rb") as fin:
            self.cifar_imgs = pickle.load(fin)

        # define transform
        # we keep the original 32x32 image size, same as the CIFAR experiments
        # in the ResNet paper
        # note that ToTensor scales each pixel value to [0, 1],
        # so we can apply Normalize afterwards
        if data_aug:
            # the data augmentation here follows https://arxiv.org/pdf/1409.5185.pdf
            # see the PyTorch docs for how to handle the extra dimension
            # introduced by the tuple that TenCrop() returns
            self.transformations = transforms.Compose([
                transforms.Pad(padding=4),
                transforms.TenCrop(32, vertical_flip=False),  # returns a tuple of 10 PIL images
                # stack the tuple into one [ncrops, C, H, W] tensor,
                # so each batch from the DataLoader will be 5D: [bs, ncrops, C, H, W]
                transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                transforms.Lambda(lambda tensors: torch.stack([transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))(t) for t in tensors]))
                ])  # note: the valid and test sets should not use data augmentation
        else:
            self.transformations = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
            ])    
    
    def __getitem__(self, index):
        img = Image.fromarray(self.cifar_imgs[index][0]) # convert to PIL image
        label = self.cifar_imgs[index][1]
        
        # transform the image
        img = self.transformations(img)

        return img, label

    def __len__(self):
        # return the number of original images, not the number of crops after
        # augmentation; otherwise the sampler would draw indices up to the
        # augmented length and cause an out-of-range error in __getitem__
        return len(self.cifar_imgs)
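
The idea I mentioned above (an untested sketch, and the class name is my own) is to expand the index space so that each (image, crop) pair becomes its own sample. Then shuffle=True can scatter the 10 crops of one image across different batches:

class CIFAR10FlatTenCrop(Dataset):
    """Untested sketch: one sample per (image, crop) pair, so shuffling mixes crops."""
    NCROPS = 10

    def __init__(self, data_path, dataset):
        with open(f"{data_path}/{dataset}_set.pkl", "rb") as fin:
            self.cifar_imgs = pickle.load(fin)
        self.pad = transforms.Pad(padding=4)
        self.ten_crop = transforms.TenCrop(32, vertical_flip=False)
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
        ])

    def __getitem__(self, index):
        img_idx, crop_idx = divmod(index, self.NCROPS)  # map flat index to (image, crop)
        img = Image.fromarray(self.cifar_imgs[img_idx][0])
        label = self.cifar_imgs[img_idx][1]
        crops = self.ten_crop(self.pad(img))  # tuple of 10 PIL images
        return self.to_tensor(crops[crop_idx]), label  # keep only the requested crop

    def __len__(self):
        # 10x the number of original images: each crop counts as one sample
        return len(self.cifar_imgs) * self.NCROPS

With this dataset the DataLoader yields ordinary 4D batches, so the view/repeat_interleave trick below would no longer be needed. The obvious downside is that __getitem__ builds all 10 crops only to keep one, so it does roughly 10x redundant work; since the crop positions are deterministic anyway, a RandomCrop of the padded image might give a similar effect more cheaply, but I have not verified that.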

Also, I used the trick from the PyTorch docs to deal with the dimension mismatch (TenCrop batches are 5D instead of the usual 4D) as follows:

if if_aug:
    bs, ncrops, c, h, w = img.size()
    img = img.view(-1, c, h, w)  # fold the 10 crops into the batch dimension
    label = torch.repeat_interleave(label, ncrops)  # repeat each label once per crop
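
For anyone checking the shapes, here is a quick sanity check with dummy tensors (the sizes are just examples):

img = torch.randn(4, 10, 3, 32, 32)  # [bs, ncrops, C, H, W], as the DataLoader yields it
label = torch.tensor([0, 1, 2, 3])   # [bs]
bs, ncrops, c, h, w = img.size()
img = img.view(-1, c, h, w)                     # -> [40, 3, 32, 32]
label = torch.repeat_interleave(label, ncrops)  # -> [40]: each label repeated 10 times in a row

Note that repeat_interleave keeps the 10 copies of each label consecutive, which matches the crop order produced by the view.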