Implementing ZCA whitening using torchvision.transforms

I am trying to implement ZCA whitening on CIFAR10. I know how to implement ZCA the usual way, but I want to build a transform that can be applied inside a DataLoader().
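For reference, any callable that takes a tensor and returns a tensor can serve as the per-item transform of a Dataset that DataLoader() then batches. The sketch below assumes the images are already tensors (e.g. after ToTensor()). `FlatLinearTransform` and `TensorImageDataset` are hypothetical names for illustration. The transform does roughly what torchvision.transforms.LinearTransformation does: subtract a mean vector from the flattened image and multiply by a fixed matrix. An identity matrix stands in for a real whitening matrix here.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FlatLinearTransform:
    """Hypothetical transform: subtract a per-dimension mean and multiply the
    flattened image by a fixed D x D matrix, keeping the image shape."""

    def __init__(self, matrix: torch.Tensor, mean: torch.Tensor):
        self.matrix = matrix  # shape (D, D), with D = C * H * W
        self.mean = mean      # shape (D,)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        flat = img.reshape(1, -1) - self.mean  # (1, D) row vector
        out = flat @ self.matrix               # (1, D) = row vector times matrix
        return out.reshape(img.shape)


class TensorImageDataset(Dataset):
    """Tiny stand-in for a dataset such as CIFAR10 that applies a transform per item."""

    def __init__(self, images: torch.Tensor, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        return self.transform(img) if self.transform is not None else img


# Toy data: 8 "images" of shape (3, 4, 4), so D = 48.
images = torch.randn(8, 3, 4, 4)
D = images[0].numel()
mean = images.reshape(8, -1).mean(dim=0)

# Identity matrix as a placeholder; a whitening matrix would go here instead.
transform = FlatLinearTransform(torch.eye(D), mean)
loader = DataLoader(TensorImageDataset(images, transform), batch_size=4)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 3, 4, 4])
```

Because the transform runs per item inside `__getitem__`, the matrix only has to be computed once, from the training set, before the DataLoader is created.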
I compute the data's principal vectors (using torch.svd()), and the V matrix can be used in torchvision.transforms.LinearTransformation(). However, after the rotation, ZCA also requires normalizing each dimension (multiplying it by a factor 1/(lambda)^0.5). The Normalize() transform in torchvision.transforms can only normalize each channel, not each individual pixel (dimension).
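The per-dimension scaling after the rotation does not strictly need a separate transform. Dividing each rotated coordinate by sqrt(lambda) gives the same result as first dividing the corresponding column of the rotation matrix by sqrt(lambda). The check below is a minimal sketch with these assumptions: toy random data stands in for CIFAR10, and the decomposition is an SVD of the covariance matrix (via torch.linalg.svd, the newer replacement for torch.svd()) rather than of the data itself.

```python
import torch

torch.manual_seed(0)

# Toy data: N samples with D correlated dimensions (flattened CIFAR10 would give D = 3072).
N, D = 500, 6
X = torch.randn(N, D, dtype=torch.float64) @ torch.randn(D, D, dtype=torch.float64)
Xc = X - X.mean(dim=0)  # center each dimension

# For a symmetric positive semi-definite covariance, the SVD coincides with the
# eigendecomposition: cov = U diag(S) U^T, with S holding the eigenvalues lambda.
cov = Xc.T @ Xc / (N - 1)
U, S, _ = torch.linalg.svd(cov)

# Two-step version: rotate, then divide each rotated dimension by sqrt(lambda).
two_step = (Xc @ U) / torch.sqrt(S)

# One-step version: fold the scale into the matrix. Broadcasting divides
# column j of U by sqrt(S[j]), i.e. U @ diag(1 / sqrt(S)).
folded = U / torch.sqrt(S)
one_step = Xc @ folded

print(torch.allclose(two_step, one_step))  # True
```

The folded matrix is square, so it could be handed to a single linear transform in place of the separate rotate and scale steps.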
After some derivation, I think this is the same as first multiplying the V matrix by a factor 1/D*(lambda)^0.5, where D is the number of dimensions and lambda is the vector of eigenvalues.
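For comparison, here is a sketch of the full ZCA matrix, W = V diag(1/sqrt(lambda + eps)) V^T. Unlike PCA whitening, ZCA rotates back with V^T after scaling, so the whitened data stays in the original pixel coordinates. The function name `zca_matrix` and the `eps` regularizer (which keeps near-zero eigenvalues from blowing up) are assumptions for illustration, and toy data again stands in for CIFAR10.

```python
import torch


def zca_matrix(X: torch.Tensor, eps: float = 1e-5):
    """Hypothetical helper: return (W, mean) such that (x - mean) @ W is ZCA-whitened.

    X:   (N, D) tensor of flattened training images.
    eps: small constant added to the eigenvalues (an assumed default).
    """
    mean = X.mean(dim=0)
    Xc = X - mean
    cov = Xc.T @ Xc / (X.shape[0] - 1)
    U, S, _ = torch.linalg.svd(cov)  # cov = U diag(S) U^T for a PSD matrix
    # PCA whitening would stop at U @ diag(1/sqrt(S + eps)); ZCA rotates back with U^T.
    W = U @ torch.diag(1.0 / torch.sqrt(S + eps)) @ U.T
    return W, mean


torch.manual_seed(0)
X = torch.randn(1000, 8, dtype=torch.float64) @ torch.randn(8, 8, dtype=torch.float64)
W, mean = zca_matrix(X, eps=0.0)  # eps=0 so the check below is exact up to rounding

Z = (X - mean) @ W
cov_Z = Z.T @ Z / (Z.shape[0] - 1)
print(torch.allclose(cov_Z, torch.eye(8, dtype=torch.float64), atol=1e-6))  # True
print(torch.allclose(W, W.T))  # True: the ZCA matrix is symmetric
```

With CIFAR10 the flattened dimension is D = 3 * 32 * 32 = 3072, so W is a 3072 x 3072 matrix. In recent torchvision versions, it could then be passed as transforms.LinearTransformation(W.float(), mean.float()) after ToTensor(). As I understand it, that transform multiplies the flattened, mean-subtracted image (as a row vector) by the matrix, which works directly here because W is symmetric. The matrix should be computed from images at the same scale the transform will see, for example ToTensor() output in [0, 1].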

Whitening data this way seems complicated, and I don't know which part I have wrong.
