Hi Guys!
I am basically trying to split an MNIST image into smaller, non-overlapping square blocks of size block_size and randomly shuffle the pixel values within each block. However, my naive implementation using nested for loops makes the function terribly slow. I have read the docs for several torch functions that I hoped could help, but I couldn't find one that does this. Could someone please help me do this with slicing instead of loops? Any help would be highly appreciated. Thanks in advance! I am attaching my code for reference.
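For the "split into blocks" part, one loop-free option is `torch.nn.functional.unfold` with `kernel_size` and `stride` both set to the block size. Each column of its output is one flattened block. Because the blocks do not overlap, `torch.nn.functional.fold` with the same settings puts every pixel back. The sketch below only demonstrates that round trip. The batch `x` is a hypothetical random stand-in for an MNIST batch, and `unfold` is used on a float tensor, as `ToTensor()` produces:

```python
import torch
import torch.nn.functional as F

x = torch.rand(8, 1, 28, 28)  # hypothetical MNIST-sized batch (B, C, H, W)
b = 2                         # block_size

# Shape (B, C*b*b, L): each column is one flattened b x b block,
# with L = (H // b) * (W // b) blocks ordered row by row across the image.
blocks = F.unfold(x, kernel_size=b, stride=b)

# The blocks do not overlap, so fold with identical settings reassembles the image exactly.
restored = F.fold(blocks, output_size=(x.shape[2], x.shape[3]), kernel_size=b, stride=b)

print(blocks.shape)              # torch.Size([8, 4, 196])
print(torch.equal(restored, x))  # True
```

With a single channel, the pixels of each block lie along dim 1 of `blocks`. With several channels, dim 1 mixes channels, so shuffling along it would also move values between channels.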
```python
import torch

# Takes a batch of images (B x C x H x W) and shuffles the pixels inside each block (in place).
# I was running this on MNIST, hence I only handle channel 0 and ignore the channel dimension.
def Pixel_Shuffle(batch, block_size=2):
    it = batch.shape[2] // block_size  # blocks per row/column (assumes square images)
    for img in batch:
        for i in range(it ** 2):
            row = (i // it) * block_size  # top-left corner of the i-th block
            col = (i % it) * block_size
            img[0, row:row + block_size, col:col + block_size] = Block_Shuffle(
                img[0, row:row + block_size, col:col + block_size])
    return batch

def Block_Shuffle(block):
    # Apply a random permutation to the block's pixels and restore its shape
    perm = torch.randperm(block.numel())
    return block.reshape(-1)[perm].reshape(block.shape)
```
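A possible loop-free version of `Pixel_Shuffle` is sketched below. It is not from the original post. It assumes H and W are divisible by `block_size`, and the name `pixel_shuffle_blocks` is hypothetical. The image is reshaped so that each block becomes the last dimension. A separate random permutation is then drawn for every block by taking `argsort` of uniform noise, the permutation is applied with `gather`, and the reshape is undone. Unlike the loop version, it shuffles every channel (each independently) and returns a new tensor instead of modifying `batch` in place:

```python
import torch

def pixel_shuffle_blocks(batch: torch.Tensor, block_size: int = 2) -> torch.Tensor:
    """Shuffle pixels within each non-overlapping block_size x block_size block.

    batch: (B, C, H, W) with H and W divisible by block_size (assumption).
    Every block of every channel of every image gets its own random permutation.
    """
    B, C, H, W = batch.shape
    b = block_size
    assert H % b == 0 and W % b == 0, "H and W must be divisible by block_size"
    nh, nw = H // b, W // b

    # (B, C, H, W) -> (B, C, nh, b, nw, b) -> (B, C, nh, nw, b, b) -> (B, C, nh, nw, b*b)
    blocks = (batch.reshape(B, C, nh, b, nw, b)
                   .permute(0, 1, 2, 4, 3, 5)
                   .reshape(B, C, nh, nw, b * b))

    # argsort of i.i.d. uniform noise gives an independent random permutation per block
    perm = torch.rand(blocks.shape, device=batch.device).argsort(dim=-1)
    shuffled = blocks.gather(-1, perm)

    # Invert the reshape/permute to get back to (B, C, H, W)
    return (shuffled.reshape(B, C, nh, nw, b, b)
                    .permute(0, 1, 2, 4, 3, 5)
                    .reshape(B, C, H, W))
```

For example, `pixel_shuffle_blocks(images, block_size=4)` on a `(64, 1, 28, 28)` MNIST batch shuffles each 4x4 block independently. The whole batch is handled by a few tensor operations, so there is no Python loop over images or blocks.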