Is my network causing poor denoising performance?

Hi folks, I’ve been working on image denoising for a couple of months, but I can’t seem to fix my denoiser’s poor performance. My model is below; please refer to the class DnCNN at the bottom.

import torch
import torch.nn as nn


# Ancillary network: estimates a per-pixel noise level map from the noisy input
class FCN(nn.Module):
    def __init__(self):
        super(FCN, self).__init__()
        self.fcn = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.fcn(x)


# Residual block: two 3x3 conv + BN + ReLU layers, with an identity skip
# connection applied only when the input and output shapes match
class resblock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.strided_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, stride=1, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        conv_block = self.strided_conv(x)
        # add the identity skip only when the channel count is unchanged
        if conv_block.size() == x.size():
            out = x + conv_block
            return out
        else:
            return conv_block


# Main denoising trunk: nine chained residual blocks
class dncnn_block(nn.Module):
    def __init__(self, in_channels, nc, out_channels):
        super().__init__()
        self.resblock1 = resblock(in_channels, nc)
        self.resblock2 = resblock(nc, nc)
        self.resblock3 = resblock(nc, nc)
        self.resblock4 = resblock(nc, nc)
        self.resblock5 = resblock(nc, nc)
        self.resblock6 = resblock(nc, nc)
        self.resblock7 = resblock(nc, nc)
        self.resblock8 = resblock(nc, nc)
        self.resblock9 = resblock(nc, out_channels)

    def forward(self, x):
        layer1 = self.resblock1(x)
        layer2 = self.resblock2(layer1)
        layer3 = self.resblock3(layer2)
        layer4 = self.resblock4(layer3)
        layer5 = self.resblock5(layer4)
        layer6 = self.resblock6(layer5)
        layer7 = self.resblock7(layer6)
        layer8 = self.resblock8(layer7)
        layer9 = self.resblock9(layer8)

        return layer9


class DnCNN(nn.Module):
    def __init__(self, in_nc=6, out_nc=3, nc=64, nb=20, act_mode='BR'):
        super(DnCNN, self).__init__()
        assert 'R' in act_mode or 'L' in act_mode
        bias = False
        self.fcn = FCN()
        dncnn1 = dncnn_block(nc, nc, nc)
        # conv() and sequential() are helper functions (omitted here) that just build the specified layers
        head1 = conv(in_nc, nc, mode='C'+act_mode[-1], bias=bias)
        tail1 = conv(nc, out_nc, mode='C', bias=bias)
        self.model1 = sequential(head1, dncnn1, tail1)

    def forward(self, x, train_mode=True):
        # estimate the noise level map, concatenate it with the noisy input x,
        # then predict a residual that is added back onto x
        noise_level = self.fcn(x)
        concat_img = torch.cat([x, noise_level], 1)
        level1_out = self.model1(concat_img) + x
        return noise_level, level1_out

DnCNN is the network I call; it is made up of the blocks above. I have excluded the conv() and sequential() functions, but they just build and compose the specified layers. The input is the noisy image x. The idea is that a small ancillary network, FCN, estimates a noise level map, which is concatenated with the noisy image x and passed through the main network. The original image is then added back onto the predicted residual image (level1_out = self.model1(concat_img) + x) to return a denoised image.
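
For reference, a quick shape check of the forward pass looks like this (a minimal sketch, assuming 3-channel RGB input and that the conv()/sequential() helpers are defined):

net = DnCNN(in_nc=6, out_nc=3, nc=64)
x = torch.randn(1, 3, 64, 64)      # dummy noisy image
noise_level, denoised = net(x)
print(noise_level.shape)           # torch.Size([1, 3, 64, 64]) - per-pixel noise map
print(denoised.shape)              # torch.Size([1, 3, 64, 64]) - denoised image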

I am also returning noise_level (the predicted noise map) because I have ground-truth noise maps, which represent the severity of the noise. Being able to predict the noise map is a way of knowing how much noise is in the image before denoising it.
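
To give an idea of how the two outputs are used, the supervision is roughly along these lines (a sketch only; the criteria and the 0.5 weight are placeholders, not my exact training code):

import torch.nn.functional as F

noise_level, denoised = net(noisy)
loss_map = F.mse_loss(noise_level, gt_noise_map)   # compare against the ground-truth noise map
loss_img = F.mse_loss(denoised, clean)             # compare against the clean image
loss = loss_img + 0.5 * loss_map                   # placeholder weighting between the two terms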

My problem is that once the network has finished training, the result is (literally) not very smooth on fine Gaussian noise. This is unexpected because I am using training parameters identical to those of networks that perform well in this regard. I have also tried many different datasets, including the same ones used by other state-of-the-art denoising solutions. I have run out of experiments that might isolate the cause, so I am posting here in case someone sees something improper in the way the network is built.

Here is a sample of the results. From left to right: the original noisy image, my poorly denoised image, and the outputs of two similar denoisers that produce the desired smooth results. I will say that my denoiser performs OK on images with large flat regions, such as colour boards, books and illustrations.

Thanks folks.

Edit: I forgot to say that I’m training it on 3x 2080 Tis with torch.nn.DataParallel().
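
The wrapping itself is just the standard pattern (a sketch, assuming GPUs 0-2 are visible):

device = torch.device('cuda:0')
net = torch.nn.DataParallel(DnCNN().to(device), device_ids=[0, 1, 2])
noise_level, denoised = net(noisy.to(device))   # both outputs are gathered back onto cuda:0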