Hi folks! I have a function that iterates over an image using two for loops, grabbing a small local area of the image each time, operating on it, and in the end, I stitch all the new tensors together. This function is bottlenecking my code as it does not run in parallel.
Lets say I have an image tensor: [4 x 4 x 1] (h x w x grayscale value). If I instead split the tensor into [ 2 x 2 x 4], how would I then create a function that processes the 4 tensors individually?
I would like to replace the for loop in the code below with a function that takes the chopped up image.
def wiener_3d(I, noise_std, block_size):
# torch.set_default_tensor_type('torch.cuda.FloatTensor')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# The dimensions of the tensor are pulled from the tensor.
width = I.shape[1]
height = I.shape[0]
IR = torch.zeros(height, width, dtype=torch.float64)
# if(len(list(I.shape)) >= 3):
# frames = I.shape[2]
# else:
# bt = 1
bt = 1
bx = block_size
by = block_size
# Half block sizes to index the windowing function from the centre.
hbx = bx/2
hby = by/2
hbt = bt/2
# Calculate the number of windows that will fit in the input image to index across
# the image.
sx = (width + hbx - 1)/hbx
sy = (height + hby - 1)/hby
# Initialise the block to a tensor of ones.
win = torch.ones(by, bx, bt)
# Create the cosine windows in each dimension.
win1x = torch.cos((torch.arange(-hbx + .5, hbx - .5 + 1)/bx) * np.pi)
win1y = torch.cos((torch.arange(-hby + .5, hby - .5 + 1)/by) * np.pi)
win1t = torch.cos((torch.arange(-hbt + .5, hbt - .5 + 1)/bt) * np.pi)
# Create the 2D cosine window
for x in range(bx):
for y in range(by):
for t in range(bt):
win[y, x, t] = win1y[y]*win1x[x]*win1t[t]
if(bt == 1):
win = torch.squeeze(win)
# Pvv is our noise power. Here we estimate the noise power across the window.
# Pvv is a random variable that we do not have to reestimate as we move across the
# image in denoising.
Pvv = torch.mean(torch.pow(win, 2))*torch.numel(win)*(noise_std**2)
Pvv = Pvv.double()
# bx0 is a list of valid indices for the input image.
bx0 = torch.range(0, bx-1)
by0 = torch.range(0, by-1)
for x in range(0, int((hbx*sx)), int(hbx)):
for y in range(0, int((hby*sy)), int(hby)):
# The first part of the code deals with finding valid block coordinates by ensuring the silding
# window does not try to pick values outside of the input image range.
# tx, ty = The current block range, centered at x
# validx, validy = Takes the range of tx and ty that are valid, i.e. that are not less than 0 or
# greater than the output image size.
# cx = This corresponds to coordinates in the original images for the full blocks
# size centered at x. tx is used to clamp the
# allowable values.
# rcx, rcy = Final valid coordinates in the image, used to recontruct the new image.
# Final valid coordinates in the block.
tx = np.arange(x-hbx+1, x+hbx+1)
validx = np.arange(np.maximum(-tx[0], 0), bx - np.maximum((tx[-1]-width+1), 0))
cx = np.minimum(np.maximum(tx, 0), width - 1)
validx = validx.astype(int)
rcx = torch.as_tensor(tx[validx], dtype=torch.long)
bcx = torch.as_tensor(bx0[validx], dtype=torch.long)
ty = np.arange(y-hby+1, y+hby+1)
validy = np.arange(np.maximum(-ty[0], 0), by - np.maximum((ty[-1]-width+1), 0))
cy = np.minimum(np.maximum(ty, 0), width-1)
validy = validy.astype(int)
rcy = torch.as_tensor(ty[validy], dtype=torch.long)
bcy = torch.as_tensor(by0[validy], dtype=torch.long)
cy = torch.as_tensor(cy, dtype=torch.long)
cx = torch.as_tensor(cx, dtype=torch.long)
# We use index select to isolate only the range we are interested in.
data_block = torch.index_select(I, 0, cy)
data_block = torch.index_select(data_block, 1, cx)
# Here we are creating and applying the Wiener function.
# Firstly, we find the mean.
# Next, we zero mean the data and window the data in the block.
mean_block = torch.mean(data_block)
win_data_block = (data_block - mean_block)*win
# Next, we transform our windowed data to the frequency domain. We find the PSD by squaring the data
# In Pytorch, we must sum the imaginary and real numbers also.
freq_block = torch.rfft(win_data_block, win_data_block.ndim, onesided=False)
Pss = torch.abs(freq_block)**2
Pss = torch.sum(Pss, 2)
Pss = Pss.double()
# We add a tiny 'epsilon' value before proceeding. This is necessary to avoid zero value divisions which
# returns NA in Pytorch and halts backpropagation.
eps = 1e-15
Pss = Pss + eps
# For every pixel, we find the Wiener function as the difference between the signal PSD and the noise PSD.
# If the noise PSD is greater than the signal PSD, we choose instead to make the Wiener function zero at
# that point. We divide the output by the signal PSD to normalise each value.
H = torch.max((Pss-Pvv), torch.zeros(Pss.size(), dtype=torch.double))
H = H / Pss
H = H.unsqueeze(2).repeat(1, 1, 2)
# We apply the Wiener funtion to our windowed frequency domain data. Next we transform the data back to the
# spatial domain using the inverse fft 'irrft().
filt_freq_block = H*freq_block
filt_data_block = torch.irfft(filt_freq_block, win_data_block.ndim, onesided=False)
# We remember to add the mean offset back onto the the data. We interpolate between blocks. win^2 can be also
# be used as a spatial interpolant
filt_data_block = (filt_data_block + mean_block*win) * win
# We index the new filtered block to the valid pixels, and add it to our new image, IR using rcx and rcy
filt_data_block = torch.index_select(filt_data_block, 0, bcy)
filt_data_block = torch.index_select(filt_data_block, 1, bcx)
IR[rcy[0]:rcy[-1] + 1, rcx[0]:rcx[-1] + 1] = IR[rcy[0]:rcy[-1] + 1, rcx[0]:rcx[-1] + 1] + filt_data_block
return IR