I am trying to apply ZCA whitening matrix to my dataset. The code I use is as follows:

```
import torchvision
import torch
import torchvision.transforms as transforms
from torchvision import transforms, datasets, models
import matplotlib.pyplot as plt
import numpy as np
def show(i):
i = i.reshape((32,32,3))
m,M = i.min(), i.max()
plt.imshow((i - m) / (M - m))
plt.show()
def computeZCAMAtrix():
#This function computes the ZCA matrix for a set of observables X where
#rows are the observations and columns are the variables (M x C x W x H matrix)
#C is number of color channels and W x H is width and height of each image
root = 'cifar10/'
temp= datasets.CIFAR10(root = root,
train = True,
download = True)
#normalize the data to [0 1] range
temp.train_data=temp.train_data/255
#compute mean and std and normalize the data to -1 1 range with 1 std
mean=(temp.train_data.mean(axis=(0,1,2)))
std=(temp.train_data.std(axis=(0,1,2)))
temp.train_data=np.multiply(1/std,np.add(temp.train_data,-mean))
#reshape data from M x C x W x H to M x N where N=C x W x H
X = temp.train_data
X = X.reshape(-1, 3072)
# compute the covariance
cov = np.cov(X, rowvar=False) # cov is (N, N)
# singular value decomposition
U,S,V = np.linalg.svd(cov) # U is (N, N), S is (N,1) V is (N,N)
# build the ZCA matrix which is (N,N)
epsilon = 1e-5
zca_matrix = np.dot(U, np.dot(np.diag(1.0/np.sqrt(S + epsilon)), U.T))
return (torch.from_numpy(zca_matrix).float(), mean, std)
#this transformation is used to transform the images to [0,1] range
#then normalize to 0 mean and 1 std, and then some
#random transformation to boost data variety at each epoch
batch_size=4
(Z,mean,std) = computeZCAMAtrix()
root = 'cifar10/'
transform_train = transforms.Compose(
[
transforms.ToTensor(),
transforms.LinearTransformation(Z),
])
transform_test = transforms.Compose(
[ transforms.ToTensor(),
transforms.LinearTransformation(Z),
])
#get the training and test sets
training_set = datasets.CIFAR10(root = root,
transform = transform_train,
train = True,
download = True)
test_set = datasets.CIFAR10(root = root,
transform = transform_test,
train = False,
download = True)
#trying to apply the transformation manually to see the result it produces
X = training_set.train_data
X = X.reshape(-1, 3072)
a=X[1]
zca2=np.dot(a,Z)
print('printing truck')
show2(zca2)
training_loader = torch.utils.data.DataLoader(dataset=training_set,
batch_size=batch_size,
shuffle=False
)
test_loader = torch.utils.data.DataLoader(dataset=test_set,
batch_size=batch_size,
shuffle=False
)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# get some random training images
dataiter = iter(training_loader)
for i, (images, labels) in enumerate(training_loader ,1):
a=images[0,:,:,:]
b=images[1,:,:,:]
c=images[2,:,:,:]
d=images[3,:,:,:]
#show images
#imshow(torchvision.utils.make_grid(images))
print('%5s' % classes[labels[2]])
show(b.numpy())
break
```

As you see I compute the ZCAmatrix and feed it as input (Z) to the dataloading function. However in computation of the ZCA matrix I assume that the data is of the form 50000 x H x W x C. However if I want to apply it as a transformation during dataloading with transforms.LinearTransformation(Z), it seems I need to first convert it to tensor using ToTensor which reorders data as 50000 x C x H x W. Then the application of the ZCA matrix to the data points reorder and flattened this way does not produce what I want it to produce. For instance if I manually apply the transformation to a truck picture of the form H x W x C I get the first picture below where as the LinearTransformation produces the second picture.

How can I work around this? it also seems I need to produce the ZCA matrix using 50000 x H x W x C as reordering data the other form and computing the ZCA matrix does not produce correct results as well.