Optimized sliding window operations with GPU

Hello,
I am trying to implement this paper. The idea is as follows: say you have a 33x33 image and a 3x3 kernel whose windows do not overlap (stride 3). This yields 11 * 11 = 121 windows of size 3x3 for each of the 3 channels. For each window, we enumerate all possible rotations around the center, compute the Local Binary Pattern (LBP) value of each, and replace the window with the rotation that gives the minimum LBP.

My batch size is 128, so I have built a sliding-window tensor of shape 128x121x3x3x3, plus a min_lpb function that finds the correct rotation. I need to apply this function to all 128 * 121 * 3 windows and replace each one with its correctly rotated form, but I do not know how to vectorize the function so that it runs in parallel over all the windows on the GPU. I tried vmap but couldn't get it to map over the windows while keeping the last two dimensions intact.

Here is my custom layer:

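import multiprocessing

import torch
import torch.nn as nn
import torch.nn.functional as F
from joblib import Parallel, delayed

class RRL(nn.Module):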
  def __init__(self, kernel_size=3, padding=(0, 0)):
    super(RRL, self).__init__()
    self.FH = self.FW = self.stride = kernel_size
    self.padding = padding

  def forward(self, x):
    # x -> (batch_size, channels, height, width)
    # OH, OW - feature map height and width
    # FH, FW - kernel height and width

    # pad the input first so the sizes below describe the padded tensor
    x = F.pad(x, pad=self.padding)

    batch_size, ch, H, W = x.shape
    OH = H // self.FH
    OW = W // self.FW
    
    # create sliding windows
    x = x.unfold(2, self.FH, self.stride).unfold(3, self.FW, self.stride) # x -> (batch_size, channels, OH, OW, FH, FW)
    x = x.contiguous().view(batch_size, ch, -1, self.FH, self.FW) # x -> (batch_size, channels, OH*OW, FH, FW)
    x = x.permute(0, 2, 1, 3, 4) # x -> (batch_size, OH*OW, channels, FH, FW)

    
    # calculate minimum LBP state for all the windows
    #x = vmap(min_lpb, in_dims=(-2, -1), out_dims=(-2, -1))(x) # x -> (batch_size, OH*OW, channels, FH, FW)

    num_cores = multiprocessing.cpu_count()
    x = Parallel(n_jobs=num_cores, backend="threading")(
        delayed(min_lpb)(x[b, o, c, :, :])
        for b in range(batch_size)
        for o in range(OH*OW)
        for c in range(ch)
    )
    x = torch.stack(x).view(batch_size, OH*OW, ch, self.FH, self.FW)

    # reshape matrix into original format
    x = x.permute(0, 2, 1, 3, 4) # x -> (batch_size, channels, OH*OW, FH, FW)
    x = x.view(batch_size, ch, OH, OW, self.FH, self.FW) # x -> (batch_size, channels, OH, OW, FH, FW)
    x = x.permute(0, 1, 2, 4, 3, 5) # x -> (batch_size, channels, OH, FH, OW, FW)
    x = x.contiguous().view(batch_size, ch, H, W) # x -> (batch_size, channels, OH*FH, OW*FW)
 
    return x
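
A quick sanity check for the windowing and the inverse reshape may help here. This is a minimal sketch, assuming no padding and a height and width divisible by the kernel size; to_windows and from_windows are hypothetical helper names that mirror the unfold and permute/view steps in forward:

```python
import torch

def to_windows(x, k):
    # (b, c, H, W) -> (b, c, OH, OW, k, k), non-overlapping k x k windows
    return x.unfold(2, k, k).unfold(3, k, k).contiguous()

def from_windows(win):
    # (b, c, OH, OW, k, k) -> (b, c, OH*k, OW*k), inverse of to_windows
    b, c, oh, ow, k, _ = win.shape
    return win.permute(0, 1, 2, 4, 3, 5).reshape(b, c, oh * k, ow * k)

x = torch.arange(2 * 3 * 33 * 33, dtype=torch.float32).view(2, 3, 33, 33)
win = to_windows(x, 3)          # 11 * 11 = 121 windows per channel
restored = from_windows(win)    # should reproduce x exactly
```

If the round trip reproduces the input, any difference in the layer output comes from the per-window rotation and not from the reshaping.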

For vmap, I tried vmap(min_lpb, in_dims=(-2, -1), out_dims=(-2, -1))(x), but in_dims raises an error: it says in_dims should be a nested structure rather than a tuple, matching the structure of the input tensor, which is reported as a TreeSpec. I then tried passing a single integer to in_dims: I used -3, thinking it would keep the last two dimensions and apply the function per channel. I also tried x = x.unsqueeze(3) with in_dims=3, so that it would map over all the other dimensions and take the last two dimensions for each case. Neither works, and I am not familiar with other vectorization tools in PyTorch, nor do I know what else I could use. I adapted the multiprocessing approach from this discussion, but it is very slow. P.S. I took the unfolding idea from the discussion titled "How to implement a convolutional layer". (I can only post two links, sorry.)
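
On the in_dims error: as far as I understand, a tuple passed to in_dims is read as one entry per positional argument of the mapped function, not as a list of dimensions of a single tensor, which would explain the structure mismatch for a one-argument function. A minimal sketch of one way around it, assuming torch.vmap is available (PyTorch 2.x): flatten all leading dimensions into one and map over it with in_dims=0. The helper lbp_code is hypothetical and only computes the code; min_lpb as written would probably still fail under vmap, because "if cur_lpb < lpb" branches on tensor values.

```python
import torch

# Bit weights copied from the LPB function in the post (center weight is 0).
POWER_MX = torch.tensor([[1., 2., 4.],
                         [8., 0., 16.],
                         [32., 64., 128.]])

def lbp_code(window):
    # LBP code of a single (3, 3) window; hypothetical helper without branching
    threshold = window[1, 1]
    bits = (window >= threshold).to(window.dtype)
    return (bits * POWER_MX.to(window.device)).sum()

x = torch.rand(2, 4, 3, 3, 3)              # (batch, OH*OW, channels, FH, FW)
flat = x.reshape(-1, 3, 3)                 # a single leading dimension to map over
codes = torch.vmap(lbp_code, in_dims=0)(flat)
codes = codes.view(x.shape[:-2])           # back to (batch, OH*OW, channels)
```

Nesting several vmap calls, one per leading dimension, should give the same result; flattening first just keeps the call simple.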

Here is the function I am applying to each sliding window:

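import torchvision.transforms.functional

def rotate_mx(mx, degree):
  # rotate expects a leading channel dimension, so add one and remove it again
  return torchvision.transforms.functional.rotate(mx.unsqueeze(0), degree).squeeze(0)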

def LPB(mx):
  mid = mx.shape[0] // 2
  threshold = mx[mid, mid]

  bin_mx = (mx >= threshold)
  # bit weights for the 8 neighbours; the center gets weight 0
  power_mx = torch.tensor([
                       [1, 2, 4],
                       [8, 0, 16],
                       [32, 64, 128]
  ], dtype=torch.float32, device=mx.device)

  return (bin_mx * power_mx).sum()

def min_lpb(mx):
  num_of_surr_elements = mx.shape[0] ** 2 - 1
  rot_deg = 360 // num_of_surr_elements

  min_mx = mx.clone()
  lpb = LPB(mx)
  for degree in range(rot_deg, 360, rot_deg):
    mx = rotate_mx(mx, degree)
    cur_lpb = LPB(mx)

    if cur_lpb < lpb:
      lpb = cur_lpb
      min_mx = mx.clone()
     
  return min_mx

Can you please suggest how I can efficiently apply this function to all the windows? I looked at GEMM, but I don't think I can build a matrix that performs this minimum-LBP-based rotation.
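
One direction that avoids per-window calls entirely is to build all eight rotations of every window at once with index permutations and pick the best one with argmin. The following is only a sketch under two assumptions: the window is 3x3, and each 45-degree rotation can be treated as a cyclic shift of the 8 neighbours around the center, which is not the same as what torchvision's rotate computes with interpolation. RING, WEIGHTS and min_lbp_rotation are names made up for this illustration; the weights are copied from LPB above.

```python
import torch

# Neighbour positions in a flattened 3x3 window, clockwise from the top-left.
RING = torch.tensor([0, 1, 2, 5, 8, 7, 6, 3])
# Row-major bit weights from LPB in the post; the center (index 4) has weight 0.
WEIGHTS = torch.tensor([1., 2., 4., 8., 0., 16., 32., 64., 128.])

def min_lbp_rotation(windows):
    # windows: (..., 3, 3); returns the same shape with each window replaced
    # by the ring shift that has the smallest LBP code
    shape = windows.shape
    flat = windows.reshape(-1, 9)                               # (N, 9)
    n = flat.shape[0]
    ring = RING.to(flat.device)
    # perms[k] maps output positions to input positions for a shift of k steps
    perms = torch.arange(9, device=flat.device).repeat(8, 1)    # (8, 9)
    for k in range(8):
        perms[k, ring] = ring.roll(k)
    candidates = flat[:, perms]                                 # (N, 8, 9)
    center = flat[:, 4].view(n, 1, 1)
    bits = (candidates >= center).to(flat.dtype)
    codes = (bits * WEIGHTS.to(flat)).sum(-1)                   # (N, 8)
    best = codes.argmin(dim=1)                                  # (N,)
    out = candidates[torch.arange(n, device=flat.device), best]
    return out.reshape(shape)

# The only neighbour above the center starts at weight 128 (bottom-right)
# and ends up at weight 1 (top-left).
w = torch.tensor([[0., 0., 0.], [0., 5., 0.], [0., 0., 9.]])
out = min_lbp_rotation(w)
```

Because everything is expressed as indexing and reductions on one tensor, the whole batch of shape (batch_size, OH*OW, channels, 3, 3) can be passed in a single call on the GPU.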