Hi K. Frank,
Thank you very much for your response. I've actually learned a bit from some of your other answers on the forum while working on previous projects. I'll give some more context on what I'm doing, and more code. You can save time by skipping to the end to see what I've attempted. It would probably be too much to share everything needed to run a test of this module, since an image and other modules are involved.
I'm training an autoencoder to denoise an image by:
1. Using the autoencoder/UNET to output a blurry mask/residual image.
2. Subtracting this mask from the original image to get a noisy texture map.
3. Filtering this noisy texture map in the frequency domain with a Wiener filter.
4. Adding the output of (3) back onto the output of the autoencoder to restore detail to the final image.

Basically, it's a two-part denoiser where the network implicitly learns to create a mask over some parts of the image and lets the Wiener filter take over.
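The four steps can be sketched as a composition of two callables. This is only an illustration: `blur_stand_in` (a box blur) is a hypothetical placeholder for the UNET and `identity_filter` is a hypothetical placeholder for the Wiener filter; neither is the actual network or filter.

```python
import torch
import torch.nn.functional as F

def blur_stand_in(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the UNET: a 5x5 box blur that keeps the image size.
    return F.avg_pool2d(x, kernel_size=5, stride=1, padding=2, count_include_pad=False)

def identity_filter(texture: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the Wiener filter: passes the texture map unchanged.
    return texture

def two_part_denoise(x, net, freq_filter):
    blurry = net(x)                   # step 1: blurry mask / residual image
    texture = x - blurry              # step 2: noisy texture map
    filtered = freq_filter(texture)   # step 3: frequency-domain filtering
    return blurry + filtered          # step 4: add the filtered detail back

x = torch.randn(1, 1, 32, 32)
out = two_part_denoise(x, blur_stand_in, identity_filter)
```

With an identity filter the output equals the input (blurry + (x - blurry) = x), so any denoising comes entirely from what the frequency filter removes from the texture map; with a filter that removes everything, the output is just the network's blurry image.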
I'm trying to make the Wiener filter learn a mapping for attenuating certain frequency magnitudes, instead of just making it a function of the noise variance Pvv relative to the signal power spectrum Pss.
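Written out, the classic gain in the "before" code below and the smooth curve from the commented-out `train_curve` line in the full listing are as follows (a transcription of the code, with ω indexing a frequency bin of the block; Pss is the sum of the squared real and imaginary parts):

$$H(\omega) = \frac{\max\big(P_{ss}(\omega) - P_{vv},\ 0\big)}{P_{ss}(\omega)}$$

$$H_{k,s}(\omega) = \frac{1}{s}\log\!\Big(1 + e^{\,s\,(P_{ss}(\omega) - k\,P_{vv})}\Big)$$

As s grows, H_{k,s} approaches max(Pss - k·Pvv, 0), i.e. the numerator of the classic gain with the noise power scaled by k. Unlike the classic gain it is not divided by Pss (that normalisation is commented out in the code), so it is not confined to [0, 1].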
Here is what I had before, where H is the filter vector to be applied to each value in the frequency block tensor:
```python
freq_block = torch.rfft(win_data_block, win_data_block.ndim, onesided=False)
Pss = torch.abs(freq_block)**2
Pss = torch.sum(Pss, 2)
Pss = Pss.double()
H = torch.max((Pss-Pvv), torch.zeros(Pss.size(), dtype=torch.double)) / Pss
H = H.unsqueeze(2).repeat(1, 1, 2)
filt_freq_block = H*freq_block
```
This compares the power spectrum of the signal with the noise power and attenuates each frequency bin more strongly the closer its power is to the noise power. Pss is the power spectrum of a 16x16 window (values, let's say, between 0 and 30) and Pvv is the constant noise power (let's say 1.5). I want to replace this with a new frequency coring function that has trainable parameters. Here is the filter I'm working with, which has the problem we are discussing. Most of the code deals with indexing and figuring out boundaries, so I'll repeat the part I'm struggling with after the full listing:
```python
import numpy as np
import torch

def add_noise(img, std):
    noise = torch.randn(img.size()) * std
    noisy_image = img + noise
    return noisy_image

def wiener_3d(I, noise_std, block_size, k, s):
    width = I.shape[1]
    height = I.shape[0]
    IR = torch.zeros(height, width, dtype=torch.float64)
    # if(len(list(I.shape)) >= 3):
    #     frames = I.shape[2]
    # else:
    #     bt = 1
    bt = 1
    bx = block_size
    by = block_size
    hbx = bx/2
    hby = by/2
    hbt = bt/2
    sx = (width + hbx - 1)/hbx
    sy = (height + hby - 1)/hby

    # Separable cosine window: product of 1-D windows along y, x and t
    win = torch.ones(by, bx, bt)
    win1x = torch.cos((torch.arange(-hbx + .5, hbx - .5 + 1)/bx) * np.pi)
    win1y = torch.cos((torch.arange(-hby + .5, hby - .5 + 1)/by) * np.pi)
    win1t = torch.cos((torch.arange(-hbt + .5, hbt - .5 + 1)/bt) * np.pi)
    for x in range(bx):
        for y in range(by):
            for t in range(bt):
                win[y, x, t] = win1y[y]*win1x[x]*win1t[t]
    if bt == 1:
        win = torch.squeeze(win)

    # Noise power per frequency bin for windowed white noise with std noise_std
    Pvv = torch.mean(torch.pow(win, 2))*torch.numel(win)*(noise_std**2)
    Pvv = Pvv.double()
    bx0 = torch.arange(bx)
    by0 = torch.arange(by)

    # Slide half-overlapping blocks over the image
    for x in range(0, int(hbx*sx), int(hbx)):
        for y in range(0, int(hby*sy), int(hby)):
            # Column indices of the block, clamped to the image
            tx = np.arange(x-hbx+1, x+hbx+1)
            validx = np.arange(np.maximum(-tx[0], 0), bx - np.maximum((tx[-1]-width+1), 0))
            cx = np.minimum(np.maximum(tx, 0), width-1)
            validx = validx.astype(int)
            rcx = torch.as_tensor(tx[validx], dtype=torch.long)
            bcx = torch.as_tensor(bx0[validx], dtype=torch.long)
            # Row indices of the block, clamped to the image (bounded by height)
            ty = np.arange(y-hby+1, y+hby+1)
            validy = np.arange(np.maximum(-ty[0], 0), by - np.maximum((ty[-1]-height+1), 0))
            cy = np.minimum(np.maximum(ty, 0), height-1)
            validy = validy.astype(int)
            rcy = torch.as_tensor(ty[validy], dtype=torch.long)
            bcy = torch.as_tensor(by0[validy], dtype=torch.long)
            cy = torch.as_tensor(cy, dtype=torch.long)
            cx = torch.as_tensor(cx, dtype=torch.long)

            data_block = torch.index_select(I, 0, cy)
            data_block = torch.index_select(data_block, 1, cx)
            mean_block = torch.mean(data_block)
            win_data_block = (data_block - mean_block)*win

            # Old (pre-1.8) API: the last dimension of length 2 holds real and imaginary parts
            freq_block = torch.rfft(win_data_block, win_data_block.ndim, onesided=False)
            Pss = torch.abs(freq_block)**2
            Pss = torch.sum(Pss, 2)   # |F|^2 = real^2 + imag^2
            Pss = Pss.double()

            step1 = Pss-(k*Pvv)
            step2 = step1*s
            step3 = torch.exp(step2)
            zeros = torch.zeros_like(step3)
            step4 = torch.logsumexp([step3, zeros], 0)
            step5 = step4/s
            # train_curve = torch.log(1+torch.exp((Pss-k*Pvv)*s))/s
            H = torch.max(step5, torch.zeros(Pss.size(), dtype=torch.double))
            # H = H / normalised_Pss
            # H[H != H] = 0
            H = H.unsqueeze(2).repeat(1, 1, 2)   # same gain for real and imaginary parts
            filt_freq_block = H*freq_block
            filt_data_block = torch.irfft(filt_freq_block, win_data_block.ndim, onesided=False)
            filt_data_block = (filt_data_block + mean_block*win) * win
            # hbt = torch.round(hbt)
            filt_data_block = torch.index_select(filt_data_block, 0, bcy)
            filt_data_block = torch.index_select(filt_data_block, 1, bcx)
            # Overlap-add the filtered block into the output image
            IR[rcy[0]:rcy[-1] + 1, rcx[0]:rcx[-1] + 1] = IR[rcy[0]:rcy[-1] + 1, rcx[0]:rcx[-1] + 1] + filt_data_block
    return IR
```
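To make k and s trainable (the goal stated above), one option is to register them as `nn.Parameter` in a small module so autograd tracks them. This is a hypothetical sketch: `SmoothCoring`, `k_init` and `s_init` are made-up names, the Pss values are random stand-ins in the 0 to 30 range mentioned above, and the curve is computed with `torch.logaddexp(z, 0)`, which equals log(1 + exp(z)) without overflowing. It only checks that gradients reach k and s; plugging it into `wiener_3d` in place of step1 to step5 is an assumption.

```python
import torch
import torch.nn as nn

class SmoothCoring(nn.Module):
    """Hypothetical learnable coring gain: H = (1/s) * log(1 + exp(s * (Pss - k * Pvv)))."""

    def __init__(self, k_init: float = 1.0, s_init: float = 1.0):
        super().__init__()
        # k scales the noise power; s controls how sharp the "knee" of the curve is.
        self.k = nn.Parameter(torch.tensor(k_init, dtype=torch.float64))
        self.s = nn.Parameter(torch.tensor(s_init, dtype=torch.float64))

    def forward(self, Pss: torch.Tensor, Pvv: torch.Tensor) -> torch.Tensor:
        z = self.s * (Pss - self.k * Pvv)
        # logaddexp(z, 0) = log(exp(z) + exp(0)) = log(1 + exp(z)), computed stably.
        return torch.logaddexp(z, torch.zeros_like(z)) / self.s

coring = SmoothCoring(k_init=1.0, s_init=2.0)
Pss = torch.rand(16, 16, dtype=torch.float64) * 30   # made-up power spectrum in [0, 30)
Pvv = torch.tensor(1.5, dtype=torch.float64)          # constant noise power
H = coring(Pss, Pvv)
H.sum().backward()                                    # gradients flow to k and s
```

If s can drift towards zero or below during training, one common choice is to learn log s instead and use s = exp(log_s), which keeps it positive.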
And here, again, is the part I'm struggling with:
```python
step1 = Pss-(k*Pvv)
step2 = step1*s
step3 = torch.exp(step2)
zeros = torch.zeros_like(step3)
step4 = torch.logsumexp([step3, zeros], 0)
step5 = step4/s
# train_curve = torch.log(1+torch.exp((Pss-k*Pvv)*s))/s
H = torch.max(step5, torch.zeros(Pss.size(), dtype=torch.double))
H = H.unsqueeze(2).repeat(1, 1, 2)
filt_freq_block = H*freq_block
```
I'm not calling the function properly, but I thought I'd show you what I've done so far. I don't 100% understand what you mean when you say to add a new dimension of length 2.
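If the suggestion was to build the softplus out of `torch.logsumexp`, the "new dimension of length 2" is most likely the one created by stacking s·(Pss - k·Pvv) with a tensor of zeros; this is a reading of the suggestion, not something confirmed here. Two things go wrong in the attempt above: `torch.logsumexp` expects a single tensor rather than a Python list, and `step3` has already been exponentiated, so the call would compute log(exp(exp(z)) + 1) instead of log(exp(z) + 1). A minimal sketch, with `softplus_via_logsumexp` as a hypothetical helper name:

```python
import torch

def softplus_via_logsumexp(x: torch.Tensor, s: float) -> torch.Tensor:
    # Computes (1/s) * log(1 + exp(s*x)), the commented-out train_curve.
    z = s * x                                                # this is step2; do not exponentiate it
    stacked = torch.stack([z, torch.zeros_like(z)], dim=0)   # new leading dimension of length 2
    return torch.logsumexp(stacked, dim=0) / s               # reduce over that new dimension

x = torch.linspace(-5.0, 5.0, 11, dtype=torch.float64)
y = softplus_via_logsumexp(x, s=2.0)
print(torch.stack([x, torch.zeros_like(x)]).shape)  # torch.Size([2, 11])
```

For a fixed s, `torch.nn.functional.softplus(x, beta=s)` gives the same curve, but it takes `beta` as a plain number, so it would not learn s; the stacked version works with s as a trainable tensor.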
Thanks a million,
Clément