I am trying to split an MNIST image into smaller, non-overlapping square blocks of size
block_size and randomly shuffle the pixel values within each block. However, my naive implementation uses nested for loops, which makes the function terribly slow. I have looked through the docs for torch functions that could help,
but I couldn't find any. Could someone please help me do this using slicing? Any help would be highly appreciated! Thanks in advance! I am attaching the code for reference.
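```python
import torch

# Takes a batch of images (B x C x H x W) and shuffles the pixels inside
# each non-overlapping block_size x block_size block.
# I was running this on MNIST, hence I didn't consider the channel dimension (only channel 0 is used).
def Pixel_Shuffle(batch, block_size=2):
    it = batch.shape[-1] // block_size  # blocks per row/column (assumes square images, H == W)
    for img in batch:
        for i in range(it ** 2):
            row = (i // it) * block_size
            col = (i % it) * block_size
            # img is a view into batch, so this assignment updates batch in place
            img[0, row:row + block_size, col:col + block_size] = Block_Shuffle(
                img[0, row:row + block_size, col:col + block_size]
            )
    return batch

def Block_Shuffle(block):
    perm = torch.randperm(block.numel())  # random order of the block's pixels
    return block.reshape(-1)[perm].reshape(block.shape)
```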
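One way to drop the Python loops entirely is to reshape the batch so that every block becomes one row of length block_size * block_size. You then shuffle all rows at once and reshape back. The sketch below assumes that H and W are divisible by block_size (true for 28x28 MNIST with block sizes 2, 4, 7 or 14) and that each block should get its own independent permutation, as in the loop version. The function name `block_shuffle` is hypothetical. Unlike the loop version, it handles every channel and returns a new tensor rather than modifying `batch` in place. (Note that the built-in `torch.nn.PixelShuffle` does something unrelated: it rearranges channels into spatial positions.)

```python
import torch


def block_shuffle(batch: torch.Tensor, block_size: int = 2) -> torch.Tensor:
    """Randomly shuffle pixels inside each non-overlapping
    block_size x block_size block of a (B, C, H, W) batch.

    Hypothetical helper; assumes H and W are divisible by block_size.
    """
    B, C, H, W = batch.shape
    b = block_size
    if H % b or W % b:
        raise ValueError("H and W must be divisible by block_size")
    nh, nw = H // b, W // b

    # (B, C, H, W) -> (B, C, nh, b, nw, b) -> (B, C, nh, nw, b, b) -> (B, C, nh, nw, b*b)
    blocks = (batch.reshape(B, C, nh, b, nw, b)
                   .permute(0, 1, 2, 4, 3, 5)
                   .reshape(B, C, nh, nw, b * b))

    # argsort of uniform noise yields an independent random permutation per block
    perm = torch.rand(blocks.shape, device=batch.device).argsort(dim=-1)
    shuffled = torch.gather(blocks, -1, perm)

    # Invert the reshape/permute to get back to (B, C, H, W)
    return (shuffled.reshape(B, C, nh, nw, b, b)
                    .permute(0, 1, 2, 4, 3, 5)
                    .reshape(B, C, H, W))
```

For an MNIST batch `images` of shape (B, 1, 28, 28), calling `block_shuffle(images, block_size=4)` shuffles all 49 blocks of every image in a single set of tensor operations. If one shared permutation for all blocks is acceptable, you could replace the `torch.rand(...).argsort(...)` and `gather` step with a single `torch.randperm(b * b)` used to index the last dimension, which is slightly cheaper.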