How could I make this function work in parallel?

Hi folks! I have a function that iterates over an image using two for loops, grabbing a small local area of the image each time, operating on it, and in the end, I stitch all the new tensors together. This function is bottlenecking my code as it does not run in parallel.

Lets say I have an image tensor: [4 x 4 x 1] (h x w x grayscale value). If I instead split the tensor into [ 2 x 2 x 4], how would I then create a function that processes the 4 tensors individually?

I would like to replace the for loop in the code below with a function that takes the chopped up image.


def wiener_3d(I, noise_std, block_size):
    # torch.set_default_tensor_type('torch.cuda.FloatTensor')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    # The dimensions of the tensor are pulled from the tensor.
    width = I.shape[1]
    height = I.shape[0]

    IR = torch.zeros(height, width, dtype=torch.float64)

    # if(len(list(I.shape)) >= 3):
    #    frames = I.shape[2]
    # else:
    #   bt = 1

    bt = 1
    bx = block_size
    by = block_size

    # Half block sizes to index the windowing function from the centre.
    hbx = bx/2
    hby = by/2
    hbt = bt/2

    # Calculate the number of windows that will fit in the input image to index across
    # the image.
    sx = (width + hbx - 1)/hbx
    sy = (height + hby - 1)/hby

    # Initialise the block to a tensor of ones.
    win = torch.ones(by, bx, bt)
    # Create the cosine windows in each dimension.
    win1x = torch.cos((torch.arange(-hbx + .5, hbx - .5 + 1)/bx) * np.pi)
    win1y = torch.cos((torch.arange(-hby + .5, hby - .5 + 1)/by) * np.pi)
    win1t = torch.cos((torch.arange(-hbt + .5, hbt - .5 + 1)/bt) * np.pi)

    # Create the 2D cosine window 
    for x in range(bx):
        for y in range(by):
            for t in range(bt):
                win[y, x, t] = win1y[y]*win1x[x]*win1t[t]

    if(bt == 1):
        win = torch.squeeze(win)

    # Pvv is our noise power. Here we estimate the noise power across the window.
    # Pvv is a random variable that we do not have to reestimate as we move across the
    # image in denoising. 
    Pvv = torch.mean(torch.pow(win, 2))*torch.numel(win)*(noise_std**2)
    Pvv = Pvv.double()
    # bx0 is a list of valid indices for the input image.
    bx0 = torch.range(0, bx-1)
    by0 = torch.range(0, by-1)



    for x in range(0, int((hbx*sx)), int(hbx)):
        for y in range(0, int((hby*sy)), int(hby)):

            # The first part of the code deals with finding valid block coordinates by ensuring the silding
            # window does not try to pick values outside of the input image range.

            # tx, ty         =  The current block range, centered at x
            # validx, validy =  Takes the range of tx and ty that are valid, i.e. that are not less than 0 or
            #                   greater than the output image size.
            # cx             =  This corresponds to coordinates in the original images for the full blocks
            #                   size centered at x. tx is used to clamp the
            #                   allowable values.
            # rcx, rcy       =  Final valid coordinates in the image, used to recontruct the new image.
            #                   Final valid coordinates in the block.
            tx = np.arange(x-hbx+1, x+hbx+1)
            validx = np.arange(np.maximum(-tx[0], 0), bx - np.maximum((tx[-1]-width+1), 0))
            cx = np.minimum(np.maximum(tx, 0), width - 1)
            validx = validx.astype(int)
            rcx = torch.as_tensor(tx[validx], dtype=torch.long)
            bcx = torch.as_tensor(bx0[validx], dtype=torch.long)

            ty = np.arange(y-hby+1, y+hby+1)
            validy = np.arange(np.maximum(-ty[0], 0), by - np.maximum((ty[-1]-width+1), 0))
            cy = np.minimum(np.maximum(ty, 0), width-1)
            validy = validy.astype(int)
            rcy = torch.as_tensor(ty[validy], dtype=torch.long)
            bcy = torch.as_tensor(by0[validy], dtype=torch.long)

            cy = torch.as_tensor(cy, dtype=torch.long)
            cx = torch.as_tensor(cx, dtype=torch.long)
            # We use index select to isolate only the range we are interested in.
            data_block = torch.index_select(I, 0, cy)
            data_block = torch.index_select(data_block, 1, cx)

            # Here we are creating and applying the Wiener function.
            # Firstly, we find the mean.
            # Next, we zero mean the data and window the data in the block.
            mean_block = torch.mean(data_block)
            win_data_block = (data_block - mean_block)*win

            # Next, we transform our windowed data to the frequency domain. We find the PSD by squaring the data
            # In Pytorch, we must sum the imaginary and real numbers also.
            freq_block = torch.rfft(win_data_block, win_data_block.ndim, onesided=False)
            Pss = torch.abs(freq_block)**2
            Pss = torch.sum(Pss, 2)
            Pss = Pss.double()

            # We add a tiny 'epsilon' value before proceeding. This is necessary to avoid zero value divisions which
            # returns NA in Pytorch and halts backpropagation.
            eps = 1e-15
            Pss = Pss + eps

            # For every pixel, we find the Wiener function as the difference between the signal PSD and the noise PSD.
            # If the noise PSD is greater than the signal PSD, we choose instead to make the Wiener function zero at
            # that point. We divide the output by the signal PSD to normalise each value.
            H = torch.max((Pss-Pvv), torch.zeros(Pss.size(), dtype=torch.double))
            H = H / Pss

            H = H.unsqueeze(2).repeat(1, 1, 2)

            # We apply the Wiener funtion to our windowed frequency domain data. Next we transform the data back to the
            # spatial domain using the inverse fft 'irrft().
            filt_freq_block = H*freq_block
            filt_data_block = torch.irfft(filt_freq_block, win_data_block.ndim, onesided=False)
            # We remember to add the mean offset back onto the the data. We interpolate between blocks. win^2 can be also
            # be used as a spatial interpolant
            filt_data_block = (filt_data_block + mean_block*win) * win

            # We index the new filtered block to the valid pixels, and add it to our new image, IR using rcx and rcy
            filt_data_block = torch.index_select(filt_data_block, 0, bcy)
            filt_data_block = torch.index_select(filt_data_block, 1, bcx)
            IR[rcy[0]:rcy[-1] + 1, rcx[0]:rcx[-1] + 1] = IR[rcy[0]:rcy[-1] + 1, rcx[0]:rcx[-1] + 1] + filt_data_block
    return IR