【VAE】The R2 score of my VAE is very poor when training for signal denoising

I want to use VAE to denoise multi-channel signals(BS × 12 × 256, where 12 is the number of channels, i.e., 12 sensors, and 256 is the time window length of a single sample, i.e., the sampling step of sensors. BS × 256 is all the time steps collected by a single sensor). However, directly using traditional convolutional VAE requires flattening the data to 1D to achieve better results,the reconstruction effect is very poor when directly inputting the 12 channels. Therefore, I suspect that the fixed convolution kernel cannot learn the time pattern of the signals. Multiple-scale convolution kernels were added to the VAE, but the final R2 SCORE still remains around 0.15 and cannot increase. What is the situation?

My VAE code is as follows

class VAE(nn.Module):  # define the Vae model
    def __init__(self, dims, latent_dim):  
        super(VAE, self).__init__()  
        self.latent_dim = latent_dim  
        self.in_channel = dims

        # Define different kernel_size to learn signal feature-patterns as much as possible
        self.short_kernel = 2
        self.long_kernel = 15
        
        # Same as padding = 'SAME' in tensorflow 
        self.stride = 2
        pad_cal = lambda a: (a - 1) // 2
        
        # Not changing channels,to extract the feature along Time_length dimension 
        # Define the Conv sequentials of short_kernel and long_kernel
        self.short_kernel_conv = nn.Sequential(
            nn.Conv1d(self.in_channel, self.in_channel, kernel_size=self.short_kernel,
                      stride=self.stride, padding=pad_cal(self.short_kernel)),
            nn.BatchNorm1d(self.in_channel),
            nn.LeakyReLU(),
            nn.Conv1d(self.in_channel, self.in_channel,
                      kernel_size=self.short_kernel,
                      stride=self.stride, padding=pad_cal(self.short_kernel)),
            nn.BatchNorm1d(self.in_channel),
            nn.LeakyReLU(),
            nn.Conv1d(self.in_channel, self.in_channel,
                      kernel_size=self.short_kernel,
                      stride=self.stride, padding=pad_cal(self.short_kernel)),
            nn.BatchNorm1d(self.in_channel),
            nn.LeakyReLU(),
        )
        self.long_kernel_conv = nn.Sequential(
            nn.Conv1d(self.in_channel, self.in_channel, kernel_size=self.long_kernel,
                      stride=self.stride, padding=pad_cal(self.long_kernel)),
            nn.BatchNorm1d(self.in_channel),
            nn.LeakyReLU(),
            nn.Conv1d(self.in_channel, self.in_channel, kernel_size=self.long_kernel,
                      stride=self.stride, padding=pad_cal(self.long_kernel)),
            nn.BatchNorm1d(self.in_channel),
            nn.LeakyReLU(),
            nn.Conv1d(self.in_channel, self.in_channel, kernel_size=self.long_kernel,
                      stride=self.stride, padding=pad_cal(self.long_kernel)),
            nn.BatchNorm1d(self.in_channel),
            nn.LeakyReLU(),
                                              )
        # To reduce the dim of concreted channels 
        self.conv_2 = nn.Conv1d(2 * self.in_channel, self.in_channel, kernel_size=1, stride=1)
  
        self.conv_mu = nn.Conv1d(self.in_channel, self.in_channel, kernel_size=1, stride=1)
        self.conv_var = nn.Conv1d(self.in_channel, self.in_channel, kernel_size=1, stride=1) 
        
        # Define the ConvTranspose sequential
        self.transconv = nn.Sequential(
            nn.ConvTranspose1d(self.in_channel, self.in_channel, kernel_size=self.short_kernel,
                               stride=self.stride),
            nn.BatchNorm1d(self.in_channel),
            nn.ConvTranspose1d(self.in_channel, self.in_channel,
                               kernel_size=self.short_kernel,
                               stride=self.stride),
            nn.BatchNorm1d(self.in_channel),
            nn.ConvTranspose1d(self.in_channel, self.in_channel,
                               kernel_size=self.short_kernel,
                               stride=self.stride),
            nn.BatchNorm1d(self.in_channel),
        )
       
    def encode(self, x):  
        short_result = self.short_kernel_conv(x)
        long_result = self.long_kernel_conv(x)
		# to concrete the feature map of short_kernel and long_kernel
        concrete_result = torch.cat((short_result, long_result), dim=1)
        reduction_result = self.conv_2(concrete_result)
        mu = self.conv_mu(reduction_result)  
        log_var = self.conv_var(reduction_result)  
        return mu, log_var 

    def decode(self, z): 
        y = self.transconv(z)
        return y  

    def reparameterize(self, mu, log_var):  
        std = torch.exp(0.5 * log_var)  
        eps = torch.randn_like(std)  
        return mu + eps * std  

    def forward(self, x):  
        mu, log_var = self.encode(x)  
        z = self.reparameterize(mu, log_var)  
        y = self.decode(z)  
        return [y, x, mu, log_var]  


def loss_fn(y, x, mu, log_var):  
    recons_loss = loss_func(y, x)  
    kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recons_loss + w * kld_loss  # w = 0.00025