I am trying to remove the structure in images (say, a 28x28 MNIST digit image) while keeping the distribution of each pixel the same. To achieve this, I need to permute each pixel independently along the batch dimension. I could use torch.randperm to shuffle the indices for each pixel, or numpy.random.permutation to do the permutation directly. However, both of these functions only operate on the first dimension of the tensor, which means I would need to call them inside a loop over all 28x28 pixel positions.

Is there a more computationally efficient way to do this?

If you permute along the batch dimension, each sample will contain pixel values from other samples in the batch. Is that your intention, or would you rather permute the pixels within each sample using one fixed permutation?

In the latter case, you can just use this code sample:

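import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms

# download=False assumes MNIST is already present under ./data
dataset = datasets.MNIST(root='./data',
                         download=False,
                         transform=transforms.ToTensor())
# One fixed permutation of the 784 pixel positions, shared by all samples
shuffle_idx = torch.randperm(28*28)
data = [dataset[i][0] for i in range(10)]
# The labels are plain Python ints, so build a tensor rather than stacking
target = torch.tensor([dataset[i][1] for i in range(10)])
# Show first sample
plt.figure()
plt.imshow(data[0][0])
# Permute the pixels within each image
data = torch.stack([x.view(-1)[shuffle_idx].view(1, 28, 28) for x in data], dim=0)
# Show after permutation
plt.figure()
plt.imshow(data[0][0])
plt.show()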

Thank you for your reply. However, correct me if I'm wrong, but this code will permute all the pixels within the same image. This means that pixels in the corner of the image, which are almost always dark, could be swapped with pixels in the middle, which have much more variation. What I am trying to achieve is, for instance, to randomly replace the top left corner pixel of image 1 with the top left corner pixel of another image in the same batch, and so on for all pixels, so as to preserve the distribution of each pixel but remove the dependencies between pixels of the same image.

The best implementation I could come up with is below but I am wondering if there is a more computationally efficient solution which avoids iterating over every pixel.

from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import torch
dataset = datasets.MNIST(root='./data',
                         download=True,
                         transform=transforms.ToTensor())
n = 10
data = torch.stack([dataset[i][0] for i in range(n)], dim=0)
# The labels are plain Python ints, so build a tensor rather than stacking
target = torch.tensor([dataset[i][1] for i in range(n)])
# Show first sample
plt.figure()
plt.imshow(data[0][0])
# Permute the pixels
# Each row of data.view(n, -1).t() holds one pixel position across the batch
data_pixelshuffled = torch.stack([x[torch.randperm(n)] for x in data.view(n, -1).t()],
                                 dim=0).t().view(-1, 28, 28)
# Show after permutation
plt.figure()
plt.imshow(data_pixelshuffled[0])
plt.show()
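
One possible way to avoid the Python loop over pixels is to draw uniform noise of shape (n, 784), argsort it along the batch dimension to get one independent permutation per pixel position, and then apply all of the permutations at once with torch.gather. The sketch below is only an illustration of that idea: the function name shuffle_pixels_across_batch is made up here, and a random tensor stands in for the MNIST batch so that the snippet runs without downloading anything.

```python
import torch


def shuffle_pixels_across_batch(data):
    """Independently permute every pixel position along the batch dimension.

    Hypothetical helper name; `data` is any tensor whose first dim is the batch.
    """
    n = data.shape[0]
    flat = data.reshape(n, -1)  # (n, num_pixels)
    # Argsorting i.i.d. uniform noise along dim 0 gives, for each column,
    # a random permutation of range(n), independent across columns.
    noise = torch.rand(flat.shape, device=flat.device)
    perm = noise.argsort(dim=0)  # (n, num_pixels), dtype int64
    # gather with dim=0: shuffled[i, j] = flat[perm[i, j], j]
    shuffled = torch.gather(flat, 0, perm)
    return shuffled.reshape(data.shape)


# Stand-in for a batch of 10 MNIST images (assumption: shape (n, 1, 28, 28))
torch.manual_seed(0)
batch = torch.rand(10, 1, 28, 28)
batch_shuffled = shuffle_pixels_across_batch(batch)

# Each pixel position keeps exactly the same set of values across the batch
same_values_per_pixel = torch.equal(
    batch.reshape(10, -1).sort(dim=0).values,
    batch_shuffled.reshape(10, -1).sort(dim=0).values,
)
```

With the real data, calling shuffle_pixels_across_batch(data) on the (n, 1, 28, 28) tensor from the snippet above should play the same role as data_pixelshuffled, up to the extra channel dimension. The cost is one sort over the batch dimension for all pixels together, and the index tensor uses as much memory as an int64 copy of the batch.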