PyTorch WaveNet Model Training Loss Not Decreasing

Hi everyone,

I’ve been coding a WaveNet model from scratch in PyTorch, but for some reason I can’t get it to train properly. The loss stays nearly the same from one epoch to the next, and I can’t figure out why. I was hoping someone here could take a look at my code and help me debug it.

With regards to the files:

[train.py]: this is where the training logic happens

[wavenet.py]: this is the pytorch model

[dataloader.py]: this is where I load and create the data to be used for training

[utils.py]: I have some generation functions here.

A lot of this code was adapted from sources I found online. The main difference is that I pad the input to the dilated convolution so that the input and output keep the same length.
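A minimal sketch of the padding described above, written as a standalone module rather than the repo's CasualDilatedConv1D class (the name CausalConv1d is hypothetical). Padding only on the left by (kernel_size - 1) * dilation keeps the output the same length as the input and keeps output[t] from seeing any input after t.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Module):
    """Dilated 1D convolution whose output at time t only sees inputs at times <= t."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        self.left_pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)

    def forward(self, x):  # x: (batch, channels, time)
        # Pad on the left only: length is preserved and no future sample leaks in.
        return self.conv(F.pad(x, (self.left_pad, 0)))

torch.manual_seed(0)
conv = CausalConv1d(1, 1, kernel_size=2, dilation=4)
x = torch.randn(1, 1, 16)
y = conv(x)
print(y.shape)  # torch.Size([1, 1, 16])
```

Padding both sides and slicing off the right-hand overhang, the alternative mentioned in the code comments, gives the same result; left padding just avoids the extra slice.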

Would really appreciate some help on this!

GitHub: mahtanir/Wavenet, a WaveNet PyTorch implementation

My model structure is below in case you don’t want to check out the github repo:

# x is typically one channel whose number of time steps depends on the sampling rate (it can be two channels, though). Typically a 1D array.
from torch import nn
import torch
import numpy as np
import torch.nn.functional as F


class CasualDilatedConv1D(nn.Module):
    def __init__(self, res_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.conv1D = nn.Conv1d(res_channels, out_channels, kernel_size, dilation=dilation, bias=False)
        # Samples an unpadded dilated conv would trim; only used by the ALT variant below.
        self.ignoreOutIndex = (kernel_size - 1)*dilation
    
    def forward(self, x):
        # Left-pad by (k - 1) * dilation: the output keeps the input length, and output[t]
        # only depends on inputs up to t (causal). Without padding, every layer would
        # shorten the sequence by (k - 1) * dilation samples.
        x = nn.functional.pad(x, ((self.kernel_size - 1)*self.dilation, 0))
        return self.conv1D(x)
        # ALT: pad (k - 1) * dilation on both sides instead and drop the surplus outputs
        # on the right: return self.conv1D(x)[..., :-self.ignoreOutIndex]
    

class ResBlock(nn.Module): #using the same kernel weights for all 
    def __init__(self, res_channels, skip_channels, kernel_size, dilation):
        super().__init__()
        self.dilatedConv1D = CasualDilatedConv1D(res_channels, res_channels, kernel_size, dilation = dilation)
        self.resConv1D = nn.Conv1d(res_channels, res_channels, kernel_size=1, dilation=1) #i.e same input output dims (see diagram)
        self.skipConv1D = nn.Conv1d(res_channels, skip_channels, kernel_size=1, dilation=1)  
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
    
    

    def forward(self, input):
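        x = self.dilatedConv1D(input)
        # Gated activation. Note: the WaveNet paper uses separate filter and gate convolutions;
        # here tanh and sigmoid are both applied to the same convolution output.
        x_tan = self.tanh(x)
        x_sigmoid = self.sigmoid(x)
        x = x_tan * x_sigmoid
        residual_output = self.resConv1D(x)  # shape = (N, res_channels, samples)
        residual_output = residual_output + input  # lengths match thanks to the causal padding
        # ALT (no padding): residual_output = residual_output + input[..., -residual_output.size(2):]
        # because an unpadded dilated conv shortens the sequence.
        skip_output = self.skipConv1D(x)  # skip-connection output (to the right in the diagram)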
        return residual_output, skip_output

class stackOfResBlocks(nn.Module):
    def __init__(self, stack_size, layer_size, res_channels, skip_channels, kernel_size):
        super().__init__()
        dilations = self.buildDilations(stack_size, layer_size)
        self.resBlockArr = [] 
        for stack in dilations:
            for dilation in stack:
                self.resBlockArr.append(ResBlock(res_channels, skip_channels, kernel_size, dilation))

    
    def buildDilations(self, stack_size, layer_size):
        dilations_arr_all = []
        for stack in range(stack_size):  # each "stack" is one cycle of dilations 1, 2, 4, ...; a flat list would also work
            dilation_arr = [] 
            for j in range(layer_size):
                dilation = min(2**j, 512)  # cap dilations at 512
                dilation_arr.append(dilation)
            dilations_arr_all.append(dilation_arr)
        return np.array(dilations_arr_all)   

    def forward(self, x):
        residual_outputs = [] 
        for resBlock in self.resBlockArr:
            x, residual = resBlock(x)
            residual_outputs.append(residual)
        return x, torch.stack(residual_outputs)  # stacks along a new dim 0: (layers, N, skip_channels, samples)

class DenseLayer(nn.Module):  # output head: ReLU -> 1x1 conv -> ReLU -> 1x1 conv
    def __init__(self, res_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1D = nn.Conv1d(res_channels, res_channels, kernel_size=1, dilation=1, bias=False)
        self.conv2nD = nn.Conv1d(res_channels, out_channels, kernel_size=1, dilation=1, bias=False)
        self.softmax = nn.Softmax(dim=1) 

    def forward(self, skipConnections):
        # skipConnections: (layers, N, skip_channels, samples), as stacked by stackOfResBlocks.
        # Summing over dim 0 adds up the skip outputs of all residual blocks.
        out = torch.sum(skipConnections, dim=0)  # -> (N, skip_channels, samples)
        out = self.relu(out)
        out = self.conv1D(out)
        out = self.relu(out)
        out = self.conv2nD(out)
        return out  # raw logits, (N, out_channels, samples)
        # No softmax here: nn.CrossEntropyLoss applies log-softmax internally.


class Wavenet(nn.Module):
    def __init__(self, res_channels, out_channels, skip_channels, kernel_size, stack_size, layer_size):  # stack_size and layer_size set how many residual blocks are stacked
        super().__init__()
        self.stack_size = stack_size 
        self.layer_size = layer_size
        self.kernel_size = kernel_size

        self.casualConv1D = CasualDilatedConv1D(256, res_channels, kernel_size, dilation=1)  # 256 input channels: one per mu-law class of the one-hot encoded input
        self.resBlockStack = stackOfResBlocks(stack_size, layer_size, res_channels, skip_channels, kernel_size)
        self.denseLayer = DenseLayer(skip_channels, out_channels)
    
    def calculateReceptiveField(self):
        # Each layer with dilation 2**i reaches back (kernel_size - 1) * 2**i samples.
        # Only needed without padding, where every residual block would trim that many samples.
        sum_val = np.sum([(self.kernel_size - 1) * 2**i for i in range(self.layer_size)] * self.stack_size)
        return sum_val
    
    def forward(self, x):
        x = one_hot(x, self.kernel_size)
        x = self.casualConv1D(x)
        # print('conv1d 2/ dilation post shape: ', x.shape, '\n')
        final_res_output, skip_connections = self.resBlockStack(x)  # final residual output is not used
        # ALT (only needed without padding): trim each skip output to the final length before summing:
        # skip_output = sum([skip[..., -final_res_output.shape[-1]:] for skip in skip_connections])
        return self.denseLayer(skip_connections)
    
# class WavenetClassifier(nn.Module):
#     def __init__(self, ):
#         super().__init__()
#         self.Wavenet = Wavenet(32, 256, 512, 2, 10,  5) #if we want to one hot encode the notes may have to convert this to 256. 
#         #32 = 24 in image, 512 = 128 in image

def one_hot(x, kernel_size):
    x = torch.tensor(np.array(x))
    # print('shape here!', x.shape)
    # x = nn.functional.pad(x, (kernel_size - 1, kernel_size - 1)) #IF we don't need this need to add to input. 
    # print('shape here!', x.shape)
    one_hot = F.one_hot(x, num_classes=256) 
    # print(one_hot.shape, one_hot[0])
    tf_shape = (1, -1, 256) #so rows actually are the points! I THINK! but then the way the conv channel works is weird... But image also shows like this 
    py_shape = (1, 256, -1)
    one_hot = torch.reshape(one_hot, py_shape)
    one_hot = one_hot.to(torch.float32)
    # print('one_hot vector input: \n', one_hot, '\n', one_hot.shape, '\n')
    return one_hot
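One thing that may be worth checking in the one_hot helper above: torch.reshape only reinterprets the (T, 256) buffer as (1, 256, T). It does not move the class axis, so channel c at time t no longer holds the one-hot value for sample t. A sketch of a layout-preserving version using transpose (one_hot_for_conv is a hypothetical name), compared against the reshape:

```python
import torch
import torch.nn.functional as F

def one_hot_for_conv(codes, num_classes=256):
    """Turn a 1D sequence of integer codes into a (1, num_classes, T) float tensor for Conv1d."""
    codes = torch.as_tensor(codes, dtype=torch.long)            # (T,)
    oh = F.one_hot(codes, num_classes=num_classes).float()      # (T, C): one row per time step
    return oh.transpose(0, 1).unsqueeze(0)                      # (1, C, T): channels first

codes = torch.tensor([3, 0, 255, 7])
good = one_hot_for_conv(codes)
bad = torch.reshape(F.one_hot(codes, num_classes=256).float(), (1, 256, -1))

print(good[0, :, 2].argmax().item())  # 255: time step 2 still encodes code 255
print(bad[0, :, 2].sum().item())      # 0.0: after reshape, time step 2 has no "hot" entry at all
```

permute(1, 0) would work just as well as transpose(0, 1); the point is to swap the axes rather than reshape.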

And the train file is as follows:

from wavenet import *  
# from wavenet_copy import *
import datetime
from dataloader import * 
import numpy as np
from tqdm import tqdm
from scipy.io.wavfile import write
from IPython.display import Audio
import sounddevice as sd




class Train():

    def __init__(self, config) -> None:
        self.layer_size = config.layer_size 
        self.stack_size = config.stack_size #repeat 

        self.res_channels = config.res_channels
        self.skip_channels = config.skip_channels

        self.mu = 256 
        self.batch_size = config.batch_size
        self.epochs = config.epochs

        self.seq_length = config.seq_length 
    
    def save_params(self, model):
        torch.save(model.state_dict(), 'models/best2/wavenet.pth')
        # model.save_params('models/best_perf/' + datetime.datetime.now())

    def train(self):
        print('at train')
        # Wavenet()
        net = Wavenet(self.res_channels, self.mu, self.skip_channels, 2, self.stack_size, self.layer_size)
        self.net = net 
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
        n_steps = self.batch_size
        print('loading music...')
        fs, data = load_music('data_parametric-2')
        minLoss = None 
        data_generator = data_generation(data, fs, self.seq_length, self.mu, None)  # generate training data lazily
        for i in tqdm(range(self.epochs)):
            loss = 0 
            for j in tqdm(range(self.batch_size), leave=False):  # one randomly drawn sequence per optimizer step (batch size 1)
                sample = next(data_generator)
                # Forward pass
                x = sample[:-1]  # inputs: the window without its last sample
                y = sample[1:]   # targets: the same window shifted one step ahead
                y = one_hot_utils(y)  # (1, T, 256)

                y_hat = net(x)  # net one-hot encodes x internally; output is (1, 256, T)
                y_hat = torch.permute(y_hat, (0, 2, 1))  # -> (1, T, 256), the same layout as y

                loss_criterion = criterion(y_hat, y)
                loss = loss + loss_criterion.item()
                loss_criterion.backward()
                optimizer.step()
                optimizer.zero_grad()  # reset gradients after every single-sequence update

            with torch.no_grad():
                agg_loss = loss / self.batch_size
                print(f"loss for epoch {i} : {agg_loss} \n")
                # Per-epoch losses are noisy, so save a checkpoint whenever a new best average appears.
                if minLoss is None or agg_loss < minLoss:
                    minLoss = agg_loss
                    self.save_params(net)
        return net
    
    def bestModel(self): #load best model for given architecture 
        model =  Wavenet(self.res_channels, self.mu, self.skip_channels, 2, self.stack_size, self.layer_size)
        model.load_state_dict(torch.load('models/best2/wavenet.pth'))
        self.net = model
        return model 
    
    def generate_slow(self, x, model, n, dilation_depth, n_repeat):
        dilations = [2**i for i in range(dilation_depth)] * n_repeat  # same dilations as the network
        reference_window = sum(dilations)
        x_generated = x.copy()
        np.save("initial_wav_alt.npy", decode_mu_law(x_generated.copy()))
        with torch.no_grad():  # sampling needs no autograd graph
            for i in tqdm(range(n)):
                # Only the most recent samples can influence the next output.
                y = model(x_generated[-reference_window - 1:])  # (1, 256, samples)
                y_next = np.squeeze(y.argmax(1).numpy())[-1]  # most likely class at the last step
                # Similar to LSTM generation, each prediction is fed back in, but the model
                # looks at a window of sum(dilations) samples instead of one previous step.
                x_generated = np.append(x_generated, y_next)
        return x_generated
    
    def generator(self, model, n, dilation_depth, n_repeat): 
         print('generating now...')
         fr, data = load_music('data_parametric-2')
         data_sample = data_generation_sample(data, fr, self.seq_length, self.mu, None)
         generated_song = self.generate_slow(data_sample, model, n, self.layer_size, self.stack_size)
         gen_wav = np.array(generated_song)
         decoded_wave = decode_mu_law(gen_wav, 256)
         np.save("wav_long_alt.npy",decoded_wave)
        #  write('test.wav', fr, decoded_wave)
        #  sd.play(decoded_wave, fr)

        #  Audio(gen_wav, rate=fr)

    

    

    
    # LOGIC: elements further back than the reference window cannot affect the last output,
    # so generation only needs the most recent reference window as input. The reference
    # window is the sum of the dilations: count back to the left-most input point that
    # contributes to the last output, adding each layer's dilation.
    # Open questions: what if x is shorter than the window (e.g. generation starts from 0)?
    # I'm also not sure about the extra "- 1" in the slice, but it seems to work. My reasoning:
    # the 1x1 convs only need the last node, whereas a 2x1 conv would also need the one before it.
    
    # def generate_slower(self, x, models, dilation_depth, n_repeat, ctx, n):
    #     dilations = [2**i for i in range(dilation_depth)] * n_repeat 
    #     res = list(x.asnumpy())
    #     for _ in trange(n):
    #         x = nd.array(res[-sum(dilations)-1:],ctx=ctx) i.e losing (k - 1)*dilation every time. So we sum all --> here only looking at 1 skip conn output
    #         y = models(x)
    #         res.append(y.argmax(1).asnumpy()[-1])
    #     return res
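The reference-window reasoning in the comments above can be made concrete with a small helper. This is a sketch under stated assumptions: dilations run 1, 2, 4, ..., 2**(layer_size - 1) per stack with no cap, every layer uses the same kernel_size, and one extra causal conv with dilation 1 sits in front of the stack, as in the Wavenet class. The name receptive_field is hypothetical.

```python
def receptive_field(kernel_size, layer_size, stack_size):
    """Number of input samples that can influence a single output sample.

    Assumes dilations 1, 2, ..., 2**(layer_size - 1) repeated stack_size times,
    plus one leading causal conv with dilation 1.
    """
    dilations = [2 ** i for i in range(layer_size)] * stack_size
    # A layer with kernel k and dilation d reaches back (k - 1) * d extra samples;
    # the leading causal conv adds (k - 1) * 1 more, and the current sample counts as 1.
    return 1 + (kernel_size - 1) * (1 + sum(dilations))

print(receptive_field(kernel_size=2, layer_size=10, stack_size=5))  # 5117
```

If those assumptions hold, kernel_size=2 gives sum(dilations) + 2 samples, one more than the sum(dilations) + 1 samples taken by the slice in generate_slow; a slightly longer window is harmless because the padding keeps lengths aligned.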

As for the utils file:


#borrowed from https://medium.com/@kion.kim/wavenet-a-network-good-to-know-7caaae735435

import numpy as np
from torch import nn
import torch.nn.functional as F
import torch


def encode_mu_law(x, mu=256):
    mu = mu-1
    fx = np.sign(x)*np.log(1+mu*np.abs(x))/np.log(1+mu) #1
    return np.floor((fx+1)/2*mu+0.5).astype(np.int64) #2

def decode_mu_law(y, mu=256):
    mu = mu-1
    fx = y/mu*2-1 #reverse of #2
    x = np.sign(fx)/mu*((1+mu)**np.abs(fx)-1) 
    return x  

def one_hot_utils(x):
    x = torch.tensor(np.array(x))
    one_hot = F.one_hot(x, num_classes=256) 
    # (T, 256) -> (1, T, 256): adds a batch dimension; the class axis stays last
    tf_shape = (1, -1, 256)
    one_hot = torch.reshape(one_hot, tf_shape)
    one_hot = one_hot.to(torch.float32)
    return one_hot
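For reference, a round-trip sketch of the mu-law companding used here, assuming inputs are already scaled to [-1, 1] (data_generation does this by dividing by the peak amplitude). It returns np.int64 codes, which avoids depending on the np.long alias that some NumPy versions lack, and decodes with codes / m * 2 - 1, the inverse of the encoder's (fx + 1) / 2 * m mapping before rounding. The function names are hypothetical.

```python
import numpy as np

def mu_law_encode(x, mu=256):
    """Map floats in [-1, 1] to integer codes in [0, mu - 1]."""
    m = mu - 1
    fx = np.sign(x) * np.log1p(m * np.abs(x)) / np.log1p(m)  # compress to [-1, 1]
    return np.floor((fx + 1) / 2 * m + 0.5).astype(np.int64)  # quantize by rounding

def mu_law_decode(codes, mu=256):
    """Approximate inverse of mu_law_encode (exact up to quantization error)."""
    m = mu - 1
    fx = codes.astype(np.float64) / m * 2 - 1               # undo the quantization mapping
    return np.sign(fx) / m * ((1 + m) ** np.abs(fx) - 1)    # undo the compression

x = np.linspace(-1, 1, 1001)
codes = mu_law_encode(x)
x_hat = mu_law_decode(codes)
print(codes.min(), codes.max())  # 0 255
```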

The dataloader file is as follows:

import os 
from scipy.io import wavfile
import numpy as np 
from utils import * 

def load_music(music_name):
    fs, data = wavfile.read(os.path.join('data', music_name + '.wav'))
    # print(fs, data)
    return fs, data

def load_music_test():
    music_name = 'data_parametric-2'
    fs, data = load_music(music_name)
    print(fs, data, data.shape)

def data_generation(data, frame_rate, seq_size, mu, ctx):
    data = data.astype(np.float32)
    max_val = np.abs(data).max()  # peak amplitude; float avoids int16 overflow in abs(-32768)
    data = data / max_val
    while True:  # infinite generator: every next() call draws a new random window
        sequence_sample_start = np.random.randint(0, data.shape[0] - seq_size)
        subsequence = data[sequence_sample_start: sequence_sample_start + seq_size]
        condensed_subsequence = encode_mu_law(subsequence, mu)
        yield condensed_subsequence
    # yield makes this a generator: it produces one window at a time and resumes where it
    # left off, instead of building every training window in memory up front.

def data_generation_sample(data, frame_rate, seq_size, mu, ctx):
    #same logic as before but now we only return one, not a generator. 
    data = data.astype(np.float32)
    max_val = np.abs(data).max()
    data = data / max_val 
    start = np.random.randint(0, data.shape[0] - seq_size)
    subset = data[start: start+ seq_size]
    return encode_mu_law(subset, mu)





if __name__ == "__main__":  # avoid loading and printing the file every time dataloader is imported
    load_music_test()
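Because the generator already yields integer mu-law codes, those codes can serve directly as class-index targets. A sketch of the shapes nn.CrossEntropyLoss expects for sequence outputs, assuming a batch of one and random stand-in tensors in place of a real model: the class axis must be dim 1, so Conv1d-style logits (N, 256, T) go in unchanged and the target is an (N, T) tensor of integer codes. In train() as posted, y_hat is permuted to (1, T, 256) and y is a one-hot (1, T, 256) tensor, so the loss appears to treat the T time steps as the classes; that may be worth ruling out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, num_classes = 100, 256

codes = torch.randint(0, num_classes, (T + 1,))  # stand-in for one mu-law window from the generator
target = codes[1:].unsqueeze(0)                  # (N=1, T): next-sample class indices, dtype long
logits = torch.randn(1, num_classes, T, requires_grad=True)  # (N, C, T): Conv1d output layout

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)  # class axis is dim 1; pass raw logits, no softmax
loss.backward()

# Sanity check: all-zero logits are a uniform prediction, i.e. a loss of log(256) ~ 5.545.
print(criterion(torch.zeros(1, num_classes, T), target).item())
```

In train() this would roughly correspond to dropping both the permute and the one_hot_utils call and passing a (1, T) long tensor built from sample[1:].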

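self.resBlockArr is a plain Python list, so the residual blocks in it are not registered as submodules and their parameters are not returned by net.parameters(). Use an nn.ModuleList instead so the optimizer actually updates these layers.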
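A minimal illustration of the difference (the two toy classes are hypothetical): modules stored in a plain list are invisible to .parameters(), so the optimizer never receives them, while nn.ModuleList registers them.

```python
import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = [nn.Linear(4, 4) for _ in range(3)]  # plain list: NOT registered

class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])  # registered

print(len(list(PlainList().parameters())))   # 0
print(len(list(Registered().parameters())))  # 6 (a weight and a bias per Linear)
```

In stackOfResBlocks this means self.resBlockArr = nn.ModuleList(); the existing .append calls and the loop in forward work unchanged. Registered submodules are also moved by .to(device) and included in state_dict().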

Hi,

Thanks a ton for the help.

Unfortunately, even after setting self.resBlockArr = nn.ModuleList(), the training loss per epoch still isn’t decreasing (it just stays flat). Any other thoughts on why this might be?

I didn’t see any other obvious issues in your code, so you might want to make sure the computation graph is not detached and that all trainable parameters receive a valid gradient by checking their .grad attribute after the first backward call.
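A sketch of that check (report_grads and the toy WithUnused module are hypothetical): after one backward() call, list every trainable parameter whose .grad is still None (not reached by the graph) or entirely zero.

```python
import torch
import torch.nn as nn

def report_grads(model: nn.Module):
    """Return (name, problem) for trainable parameters with a missing or all-zero gradient."""
    problems = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.grad is None:
            problems.append((name, "no grad"))    # parameter did not take part in the loss
        elif p.grad.abs().sum().item() == 0.0:
            problems.append((name, "zero grad"))  # reached, but the gradient is exactly zero
    return problems

class WithUnused(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)  # defined but never called in forward()

    def forward(self, x):
        return self.used(x)

torch.manual_seed(0)
model = WithUnused()
model(torch.randn(8, 4)).sum().backward()
print(report_grads(model))  # [('unused.weight', 'no grad'), ('unused.bias', 'no grad')]
```

Note that this only inspects registered parameters, so it would not have caught the plain-list problem; comparing sum(p.numel() for p in net.parameters()) against the expected parameter count covers that case.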
If this looks good, you could try to overfit a small subset (e.g. just 10 samples) of your dataset to make sure the model is capable of learning the data distribution.
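A sketch of that overfitting test under stated assumptions: ten fixed random sequences stand in for ten real training windows, and TinyCausalNet is a hypothetical two-layer causal model standing in for the Wavenet instance. The training loop is the part to reuse. If the loss on the same ten sequences does not fall well below its starting value of roughly log(256) ≈ 5.5, the problem is more likely in the data, target, or loss plumbing than in model capacity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, T = 256, 32

# Ten fixed sequences of mu-law codes (random stand-ins for ten real training windows).
data = torch.randint(0, num_classes, (10, T + 1))
inputs = F.one_hot(data[:, :-1], num_classes).float().transpose(1, 2)  # (10, 256, T)
targets = data[:, 1:]                                                  # (10, T) class indices

class TinyCausalNet(nn.Module):
    """Hypothetical stand-in model: output[t] sees inputs t - 1 and t only."""
    def __init__(self, num_classes=256, hidden=64):
        super().__init__()
        self.conv = nn.Conv1d(num_classes, hidden, kernel_size=2, padding=1)
        self.out = nn.Conv1d(hidden, num_classes, kernel_size=1)

    def forward(self, x):
        # padding=1 pads both sides; dropping the last output step keeps it causal.
        h = torch.relu(self.conv(x)[..., : x.shape[-1]])
        return self.out(h)  # (N, 256, T) logits

model = TinyCausalNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

losses = []
for step in range(300):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

With the real model, the equivalent would be to draw ten windows from data_generation once, keep them in a list, and loop over that list for many epochs instead of calling next() at every step.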