Applying manual ZCA whitening is very slow

I am trying to write a classifier for CIFAR10 images where the images are also whitened at the loader level. I have a separate function

def computeZCAMatrix(dataname):

    #compute the ZCA whitening matrix, mean and std for the given dataset

    return (torch.from_numpy(ZCAMatrix).float(), mean, std)
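
In sketch form, the function does something like this (the eps value is illustrative, and the flattening must match the C,H,W layout that ToTensor produces, since LinearTransformation is applied after it):

import numpy as np
import torch
from torchvision import datasets

def computeZCAMatrix(dataname):

    #per-channel mean and std, for transforms.Normalize
    data = datasets.CIFAR10(root='cifar10/', train=True, download=True).data
    X = data.astype(np.float64) / 255.0                #(N, 32, 32, 3)
    mean = X.mean(axis=(0, 1, 2))
    std = X.std(axis=(0, 1, 2))

    #flatten in C,H,W order so the matrix matches what ToTensor feeds it
    Xn = ((X - mean) / std).transpose(0, 3, 1, 2).reshape(X.shape[0], -1)

    #ZCA: rotate to the eigenbasis of the covariance, rescale, rotate back
    cov = Xn.T @ Xn / Xn.shape[0]                      #(3072, 3072)
    U, S, _ = np.linalg.svd(cov)
    eps = 1e-5                                         #illustrative regularizer
    ZCAMatrix = U @ np.diag(1.0 / np.sqrt(S + eps)) @ U.T

    return (torch.from_numpy(ZCAMatrix).float(), mean, std)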

I call this function once before loading my data. The data loader function is of the form

import sys

import torch
from torchvision import datasets, transforms

def getData(batch_size, dataname, Z, mean, std):

    if batch_size <= 0:
        batch_size = 256

    if dataname == 'CIFAR10':
        root = 'cifar10/'

        transform_train = transforms.Compose([
            transforms.RandomRotation(30),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.LinearTransformation(Z),
        ])

        #for the test set we do not apply the random transformations
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.LinearTransformation(Z),
        ])

        #load the training and test sets
        training_set = datasets.CIFAR10(root = root,
                                  transform = transform_train,
                                  train = True,
                                  download = True)
        
        test_set = datasets.CIFAR10(root = root,
                                  transform = transform_test,
                                  train = False,
                                  download = True)
 
    else:
        print('Currently only CIFAR10 is allowed, terminating program')
        sys.exit()
        

    training_loader = torch.utils.data.DataLoader(dataset=training_set,
                                              batch_size=batch_size,
                                              shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                            batch_size=batch_size,
                                            shuffle=False)

    
    return (training_set,test_set,training_loader,test_loader)

So the training loader applies some random transformations and then does the normalization and the whitening. For some reason this slows the code down considerably. Is applying transforms.LinearTransformation really that slow, or am I making a mistake? I use a batch size of 256 and run the code on Google Colab with the GPU enabled. Should I apply the whitening transformation to the whole dataset once, save it somewhere, and then use that? That makes sense, and I could just apply the ZCA whitening matrix as an np matrix multiplication to the train and test data, but unfortunately I do not know how to save transformed datasets.
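
Concretely, the numpy-level version I have in mind is something like the following untested sketch (the helper name and file path are placeholders, and it assumes a torchvision version whose datasets expose .data and .targets):

import numpy as np
import torch

def whitenAndSave(dataset, Z, mean, std, out_path):

    #whiten the whole dataset with a single matrix multiplication and save
    #the result, so the loader never does a per-image matmul
    X = dataset.data.astype(np.float64) / 255.0                       #(N, 32, 32, 3)
    Xn = ((X - mean) / std).transpose(0, 3, 1, 2).reshape(len(X), -1)  #C,H,W order, like ToTensor
    Xw = Xn @ Z.numpy()                                               #same product LinearTransformation computes
    images = torch.from_numpy(Xw).float().view(-1, 3, 32, 32)
    torch.save((images, torch.tensor(dataset.targets)), out_path)

#the saved tensors could then be wrapped in a TensorDataset:
#images, labels = torch.load('cifar10_zca.pt')
#training_set = torch.utils.data.TensorDataset(images, labels)

Though of course this bakes the images in before any random transformations, so the augmentations would have to move elsewhere if I go this route.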

By the way, applying ZCA whitening worsens the accuracy quite considerably as well. I assume this is because I apply it after the random transformations?

I have the same issue: LinearTransformation for CIFAR with a transformation matrix of size 3*1024 x 3*1024 is unusably slow. It seems that PyTorch does not prefetch samples while the GPU is running.
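
(With the default num_workers=0 the DataLoader does everything in the main process; prefetching only happens once you ask for worker processes, e.g.:)

training_loader = torch.utils.data.DataLoader(dataset=training_set,
                                              batch_size=256,
                                              shuffle=True,
                                              num_workers=2)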

Any news on this front? Applying ZCA whitening through LinearTransformation is unusably slow for me as well.

If you apply ZCA inside the DataLoader transforms, it will be very slow. Instead, I wrote a ZCA transformation here that is applied per minibatch, which is much faster. The usage of the ZCA transformation is as below.

for data, target in dataloader:
    zca_data = zca(data)

I couldn’t find how you compute the transformation matrices in your repo. Could you elaborate more on how to define zca to work on batches of data?

You can see how to calculate the mean and the ZCA matrix from this repo. After calculating them, save the mean and ZCA matrix using scipy.io.savemat(). Then, when training, you can load them back with scipy.io.loadmat(), using this code.
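
Roughly, a minimal self-contained version of the per-minibatch application looks like this (the .mat keys here are just examples):

import torch
from scipy.io import loadmat

class ZCA:

    #subtract the precomputed mean and apply the precomputed whitening
    #matrix to the whole batch with a single matmul
    def __init__(self, mean, zca_matrix):
        self.mean = mean                        #(D,) tensor
        self.Z = zca_matrix                     #(D, D) tensor

    def __call__(self, data):
        flat = data.view(data.size(0), -1)      #(B, D)
        white = (flat - self.mean) @ self.Z     #one matmul for the whole batch
        return white.view_as(data)

#hypothetical usage with a file written by scipy.io.savemat():
#m = loadmat('zca.mat')
#zca = ZCA(torch.from_numpy(m['mean']).float().view(-1),
#          torch.from_numpy(m['zca']).float())
#for data, target in dataloader:
#    zca_data = zca(data)

Moving data, mean and the matrix to the GPU first also lets the matmul run there.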
