Could someone help me reduce the loss on my autoencoder?

I have a U-Net style autoencoder below, with a Wiener filter I wrote in PyTorch applied at the end. The network seems to converge faster than it should and I don't know why. I have a dataset of 4000 images, and I take a random 128x128 crop each time a sample is drawn (roughly as in the sketch below). I'm using a learning rate schedule and weight decay. I've tried tuning the hyperparameters on a tiny dataset to see if anything improves, but nothing seems to work. Once the learning rate drops, the loss just bounces around rather than settling at a floor, and in some cases climbs back up.
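For reference, the cropping happens per sample, roughly like this (a minimal sketch rather than my actual loader; RandomCropDataset and the file-path handling are stand-ins):

from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class RandomCropDataset(Dataset):
    '''Sketch: loads a grayscale image and returns a random 128x128 crop
    as a 1 x 128 x 128 tensor in [0, 1].
    '''
    def __init__(self, image_paths):
        self.image_paths = image_paths          # list of file paths (assumed)
        self.crop = transforms.RandomCrop(128)  # random 128x128 crop per access
        self.to_tensor = transforms.ToTensor()  # H x W -> 1 x H x W in [0, 1]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert("L")  # single channel
        return self.to_tensor(self.crop(img))

My network is as follows: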

import torch
import torch.nn as nn
from wiener_3d import wiener_3d
from PIL import Image
import numpy as np
import random

def np_to_pil(np_imgs):
    '''Converts a batch of images in np.array format to PIL images.
    From N x C x H x W in [0, 1] to H x W (x C) in [0, 255].
    '''
    img_num = np_imgs.shape[0]
    channel_num = np_imgs.shape[1]
    ar = np.clip(np_imgs*255, 0, 255).astype(np.uint8)

    pil_imgs = []
    for i in range(img_num):
        if channel_num == 1:
            img = ar[i][0]
        else:
            img = ar[i].transpose(1, 2, 0)
        pil_imgs.append(Image.fromarray(img))

    return pil_imgs


class WienerFilter(nn.Module):
    def __init__(self, param_b=16):
        super(WienerFilter, self).__init__()
        # param_a was a learnable nn.Parameter in an earlier version;
        # only param_b is used now
        self.param_b = param_b

    def forward(self, input, std):
        # Filter each image in the batch and stack the results into a new
        # tensor; writing back into `input` in place (input[i] = tensor)
        # can interfere with autograd
        filtered = []
        for i in range(input.shape[0]):
            tensor = torch.squeeze(input[i])
            # tensor = wiener_3d(tensor, self.param_a, self.param_b)
            tensor = wiener_3d(tensor, 2 * std, self.param_b)
            filtered.append(torch.unsqueeze(tensor, 0))
        return torch.stack(filtered)


class AutoEncoder(nn.Module):
    """Autoencoder simple implementation """
    def __init__(self):
        super(AutoEncoder, self).__init__()
        # Encoder
        # conv layer
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 96, 3, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(96, 96, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(96, 96, 3, padding=1),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(192, 192, 3, padding=1),
            nn.BatchNorm2d(192),
            nn.LeakyReLU(0.1),
            nn.Conv2d(192, 192, 3, padding=1),
            nn.BatchNorm2d(192),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(192, 192, 2, 2),
            nn.BatchNorm2d(192),
            nn.LeakyReLU(0.1)
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(288, 192, 3, padding=1),
            nn.BatchNorm2d(192),
            nn.LeakyReLU(0.1),
            nn.Conv2d(192, 192, 3, padding=1),
            nn.BatchNorm2d(192),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(192, 192, 2, 2),
            nn.BatchNorm2d(192),
            nn.LeakyReLU(0.1)
        )
        self.block6 = nn.Sequential(
            nn.Conv2d(193, 96, 3, padding=1),  # 192 decoder channels + 1 input channel
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1),
            nn.Conv2d(96, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(32, 1, 3, padding=1),
            nn.LeakyReLU(0.1)
        )

        self.wiener_filter = WienerFilter()

    def forward(self, x, std):
        # Encoder; note that block2 is reused for pools 3-5, so those
        # levels share the same weights
        pool1 = self.block1(x)
        pool2 = self.block2(pool1)
        pool3 = self.block2(pool2)
        pool4 = self.block2(pool3)
        pool5 = self.block2(pool4)

        # Decoder with skip connections; block5 is likewise reused three times
        upsample5 = self.block3(pool5)
        concat5 = torch.cat((upsample5, pool4), 1)
        upsample4 = self.block4(concat5)
        concat4 = torch.cat((upsample4, pool3), 1)
        upsample3 = self.block5(concat4)
        concat3 = torch.cat((upsample3, pool2), 1)
        upsample2 = self.block5(concat3)
        concat2 = torch.cat((upsample2, pool1), 1)
        upsample1 = self.block5(concat2)
        concat1 = torch.cat((upsample1, x), 1)
        output = self.block6(concat1)

        # Residual between the input and the network output
        t_map = x - output

        # Save intermediate images for inspection, using a random even
        # integer from 0 to 100 inclusive as the filename prefix
        path = "test"
        name = random.randrange(0, 101, 2)
        pil_img = np_to_pil(t_map.detach().cpu().numpy())
        pil_img[0].save(path + "/" + str(name) + "_2_tmap.png")
        pil_img_out = np_to_pil(output.detach().cpu().numpy())
        pil_img_out[0].save(path + "/" + str(name) + "_1_net_out.png")

        filtering = self.wiener_filter(t_map, std)
        pil_img_out = np_to_pil(filtering.detach().cpu().numpy())
        pil_img_out[0].save(path + "/" + str(name) + "_3_tmapfiltered.png")

        filtered_output = output + filtering
        pil_img_out = np_to_pil(filtered_output.detach().cpu().numpy())
        pil_img_out[0].save(path + "/" + str(name) + "_5_final.png")
        return filtered_output
        

My current settings are: Adam optimizer, initial learning rate 0.001, weight decay 0.0001, learning rate decayed by a factor of 0.1 if there is no improvement for 7 epochs, and no batching (batch size 1).
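In code, that setup looks roughly like this (a sketch; the MSE loss, loader, and num_epochs are placeholders, not my exact training loop):

import torch
import torch.nn as nn

model = AutoEncoder()
criterion = nn.MSELoss()  # placeholder reconstruction loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
# cut the learning rate by 10x after 7 epochs without improvement
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=7)

for epoch in range(num_epochs):          # num_epochs: placeholder
    epoch_loss = 0.0
    for noisy, clean, std in loader:     # batch size 1, i.e. no batching
        optimizer.zero_grad()
        loss = criterion(model(noisy, std), clean)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    scheduler.step(epoch_loss / len(loader))  # drive the LR schedule off epoch loss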

I feel like I’ve tried everything at this stage. Could someone give me some advice on how to improve my network? Thank you.