Split image into non-overlapping sub-blocks

Hi Guys!

I am trying to split an MNIST image into smaller non-overlapping square blocks of size block_size and randomly shuffle the pixel values inside each block. However, my naive implementation uses nested for loops, which makes the function terribly slow. I have looked through the torch docs for functions that could help, but couldn't find any. Could someone please help me do this using slicing? Any help would be highly appreciated! Thanks in advance! I am attaching the code for reference.

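```python
import torch

# Takes a batch of images (B x C x H x W) and shuffles the pixels inside each block.
# I was executing this on MNIST (C = 1), so only channel 0 is shuffled.
def Pixel_Shuffle(batch, block_size=2):
    it = batch.shape[2] // block_size  # blocks per side (assumes square images, H == W)

    for img in batch:  # img is a view, so writing into it modifies batch in place
        for i in range(it ** 2):
            row = (i // it) * block_size  # top-left row of block i
            col = (i % it) * block_size   # top-left column of block i
            img[0, row:row + block_size, col:col + block_size] = Block_Shuffle(
                img[0, row:row + block_size, col:col + block_size]
            )

    return batch


def Block_Shuffle(block):
    # Randomly permute the pixels of one block, then restore its 2-D shape.
    perm = torch.randperm(block.numel())
    return block.reshape(-1)[perm].reshape(block.shape)
```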
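Besides the Python loops, part of the cost is that Block_Shuffle calls torch.randperm once per block. A minimal sketch of how that step can be batched, assuming the blocks have already been flattened into the rows of an (N, block_size**2) tensor: argsort i.i.d. uniform noise along each row, which gives an independent random permutation per row in a single call, then apply it with torch.gather. The name shuffle_rows is hypothetical, not a torch function.

```python
import torch

def shuffle_rows(x: torch.Tensor) -> torch.Tensor:
    """Apply an independent random permutation to every row of a 2-D tensor (N, k)."""
    # Sorting uniform noise yields a random ordering of positions for each row.
    perm = torch.rand(x.shape, device=x.device).argsort(dim=1)
    # Pick elements of each row in that random order.
    return torch.gather(x, 1, perm)
```

Each row keeps exactly its own values, only their order changes, which is the same guarantee Block_Shuffle gives for a single block.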

You could use tensor.unfold to create the patches and reshape them back as explained in this post.
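Following that suggestion, here is a sketch of a loop-free version, assuming a (B, C, H, W) tensor with H and W divisible by block_size (it also shuffles every channel, not just channel 0). Calling Tensor.unfold twice with step equal to size extracts the non-overlapping blocks, argsort of random noise plus torch.gather shuffles each block independently, and a permute and reshape put the blocks back in place. The function name pixel_shuffle_blocks is hypothetical.

```python
import torch

def pixel_shuffle_blocks(batch: torch.Tensor, block_size: int = 2) -> torch.Tensor:
    """Independently shuffle the pixels inside each non-overlapping block.

    Assumes batch has shape (B, C, H, W) with H and W divisible by block_size.
    Returns a new tensor; the input is not modified.
    """
    B, C, H, W = batch.shape
    assert H % block_size == 0 and W % block_size == 0
    nh, nw = H // block_size, W // block_size

    # step == size, so the windows do not overlap:
    # patches[b, c, i, j, p, q] == batch[b, c, i*bs + p, j*bs + q]
    patches = batch.unfold(2, block_size, block_size).unfold(3, block_size, block_size)
    patches = patches.reshape(B, C, nh, nw, block_size * block_size)

    # One random permutation per block: argsort of uniform noise along the last dim.
    perm = torch.rand(patches.shape, device=batch.device).argsort(dim=-1)
    shuffled = torch.gather(patches, -1, perm)

    # Reassemble: (B, C, nh, nw, bs, bs) -> (B, C, nh, bs, nw, bs) -> (B, C, H, W)
    shuffled = shuffled.reshape(B, C, nh, nw, block_size, block_size)
    return shuffled.permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)
```

For example, pixel_shuffle_blocks(images, block_size=4) on a batch of 28x28 MNIST images shuffles every 4x4 block in one pass. Because the blocks do not overlap, the same extraction could also be done with a plain reshape and permute; unfold just makes the block layout explicit.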