PyTorch Wavenet Model loss is not decreasing (Help)

I’m a college student trying to implement WaveNet in PyTorch. This is my first time writing custom modules for a model, and my problem is that the model won’t train. Essentially the project is this: I have a set of wav files that I read in, process, and quantize as in the WaveNet paper, then arrange into series of 1024 data points (the model takes a series of 1024 amplitudes from the wav as input and should output a tensor of 256 values describing the probability that the next item in the series takes each of the 256 quantized values).

I’m currently trying to train the model on a single music file, hoping to get it to overfit so that I can be sure it actually learns from the data. Here lies the problem: the loss won’t decrease as I train. I’ve tried making the model smaller, changing the loss function, changing what kind of layer the output layer is, and increasing the learning rate, but nothing seems to work.

I suspect that the problem is somewhere in the model itself; the way it is constructed may be wrong. It’s possible that I’ve linked the custom modules into the model in a way that interferes with backpropagation, at least that is my best guess. My code for the model and my training code is below. I would really appreciate some help!

#model https://github.com/Dankrushen/Wavenet-PyTorch/blob/master/wavenet/models.py
#https://github.com/ryujaehun/wavenet/blob/master/wavenet/networks.py
#https://medium.com/@satyam.kumar.iiitv/understanding-wavenet-architecture-361cc4c2d623
#https://discuss.pytorch.org/t/causal-convolution/3456/4

import torch
import torch.optim as optim
from torch import nn
from functools import reduce

#causal convolution (citation above)
class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super(CausalConv1d, self).__init__()
        self.pad = (kernel_size - 1) * dilation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.pad, dilation=dilation, **kwargs)
        
    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, skip_size, skip_channels, dilation=1):
        super(ResidualBlock, self).__init__()
        self.dilation = dilation
        self.skip_size = skip_size
        self.conv_s = CausalConv1d(input_channels, output_channels, kernel_size, dilation)#dim
        self.sig = nn.Sigmoid()
        self.conv_t = CausalConv1d(input_channels, output_channels, kernel_size, dilation)#dim
        self.tanh = nn.Tanh()
        self.conv_1 = nn.Conv1d(output_channels, output_channels, 1)#dim -> k = 1
        
        self.skip_conv = nn.Conv1d(output_channels, skip_channels, 1)
        
    def forward(self, x):
        o = self.sig(self.conv_s(x)) * self.tanh(self.conv_t(x))
        skip = self.skip_conv(o)
        skip = skip[:,:,-self.skip_size:] #dim control for adding skips
        residual = self.conv_1(o)
        
        return residual, skip
   
    

class WaveNet(nn.Module):#SET SKIP SIZE default
    def __init__(self, skip_size=256, num_blocks=2, num_layers=10, num_hidden=128, kernel_size=2): 
        super(WaveNet, self).__init__()
        self.layer1 = CausalConv1d(1, num_hidden, kernel_size)#dim
        self.res_stack = nn.ModuleList()
        
        for b in range(num_blocks):
            for i in range(num_layers):
                self.res_stack.append(ResidualBlock(num_hidden, num_hidden, kernel_size, skip_size=skip_size, skip_channels=1, dilation=2**i))#dim
        
        #self.hidden = nn.ModuleList(self.hidden)
        
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv1d(1,1,1)#dim
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv1d(1,1,1)#dim
        self.output = nn.Softmax()
        
    def forward(self, x):
        skip_vals = []
        #initial causal conv
        o = self.layer1(x)
        
        #run res blocks
        for i, layer in enumerate(self.res_stack):
            o, s = layer(o)
            skip_vals.append(s)
            
        #sum skip values and pass to last portion of network
        o = reduce((lambda a,b: a+b), skip_vals)
        o = self.relu1(o)
        o = self.conv1(o)
        o = self.relu2(o)
        o = self.conv2(o)
        
        return self.output(o)
    

        
#overfit model to test if it will train
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
net = WaveNet(num_layers=1)

#send to gpu
net.to(device)

criterion = nn.CrossEntropyLoss() #preproc data, stream, train, remember to reformat y label as onehot vector for classification
optimizer = optim.Adam(net.parameters(),lr=0.001)
num_epochs = 20
losses = []

_, inp = wav_to_data(data_path+'/'+wav_files[0])
data = encode(inp)

batch_size = 32

i = 0
for epoch in range(num_epochs):
    i += 1
    for s in range(0, len(data) - 1024, batch_size):
        (x, y) = create_singular_input_stream(data, s, batch_size)
        optimizer.zero_grad()

        output = net(torch.reshape(torch.FloatTensor(x).to(device), (batch_size,1,1024)))
        print(output.shape)
        print(torch.Tensor([y]).shape)

        #find loss between distributions of amplitudes
        #https://discuss.pytorch.org/t/indexerror-target-1-is-out-of-bounds-nlloss/68656
        loss = criterion(torch.squeeze(output), torch.squeeze(torch.Tensor(y)).type(torch.LongTensor).to(device))
        
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        
        print('Epoch {}/{}, Loss: {:.6f}'.format(i, num_epochs, loss.item()))
        

A helpful note:
create_singular_input_stream returns x (a 1×1024 tensor containing the series) and y (the next value in the series, an integer in [0, 255])

Again, any help is appreciated - I’d love to know what I’m doing wrong.

@Arham_Khan

Just some mistakes I have noticed:

  1. The padding in CausalConv1d is not causal. You’re going to pad (kernel_size - 1) * dilation on both sides of the input’s last dimension. The correct way is to pad the input yourself in the forward call and set padding to 0 in the Conv1d:
...
    self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=0, dilation=dilation, **kwargs)

def forward(self, x):
    x = F.pad(x, (self.pad, 0))       # only pad on the left side (F is torch.nn.functional)
    return self.conv(x)
  2. The skip_size param doesn’t make sense to me, especially in this line:

skip = skip[:,:,-self.skip_size:] #dim control for adding skips

You set skip_size to 256 to match the one-hot vector of y, which represents the value of the sample just after the 1024 input samples. But you take 256 values from the last axis, which represents time, so you’re somehow mapping 256 different time positions to 256 different amplitude values, which looks weird to me :thinking:.
Those 256 values should be taken from the second axis, which represents the hidden channels.

  3. You directly cast x to float, which makes the quantization completely useless. In my opinion, the shape of x should be (batch, 256, 1024) with one-hot vectors along the second axis, and the shape of y should be (batch, 256, 1) with a one-hot vector along the second axis (see the sketch just below this list).
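A minimal sketch of what that layout could look like, assuming the quantized samples are already integer class indices in [0, 255] (the tensor names are just for illustration):

import torch
import torch.nn.functional as F

batch, seq_len, n_classes = 8, 1024, 256
x_idx = torch.randint(0, n_classes, (batch, seq_len))    # quantized samples as class indices
y_idx = torch.randint(0, n_classes, (batch, 1))          # the sample that follows each sequence

x = F.one_hot(x_idx, n_classes).float().transpose(1, 2)  # (batch, 256, 1024), one-hot along axis 1
y = F.one_hot(y_idx, n_classes).float().transpose(1, 2)  # (batch, 256, 1), one-hot along axis 1

(Note that nn.CrossEntropyLoss would still take the integer y_idx as its target rather than the one-hot version.)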

I think you might have some misunderstanding of WaveNet, but this is normal, because the original paper didn’t write everything very clearly (to me), so I recommend taking a look at others’ implementations to get a wider view of WaveNet.

The wavenet model I have implemented:

Thanks for the thorough explanation! Sorry for the late reply, but I’ve been busy with finals. The skip size parameter has confused me as well, but I saw it pop up in various implementations, including some that may be linked in my original post. It seems they are pulling skip values for the last set of fully connected layers to act on, but I’m confused as to why we would truncate those vectors.

If I remember correctly, I had cast X to float because some of the layers in the model would not process X without it being a float. Is what you recommended essentially that for each sequence element I should pass a one-hot encoded vector to the model? So the model would be fed B batches, and each sequence of 1024 that I wished to pass to the model would be represented by 1024 one-hot encoded vectors of length 256, to represent the encoding space?

Could you comment on how the above change affects training? The way it currently works is that the input is quantized, but even in float form the elements of the vector will still be of the form (120., 243., …). In addition, how would we then process the one-hot encoded vectors using convolutions? I see in your implementation you include an Embedding element; what is the use of this element in the context of WaveNet? I see that it is a lookup table from the class to the embedding, but what is it trained to embed each class into? Basically, how is the Embedding layer used in your implementation?

Also, I now see why the padding on the causal convolution is incorrect - admittedly I had pulled that code from another post on these forums.

Are there any resources and books you’d recommend to get a good understanding of machine learning theory and model development? Particularly from the point of view of someone teaching how to design novel model architectures?

I use the Embedding layer just to map the 256 classes to different latent vectors with size equal to the residual channels, no other meaning. If you feed one-hot encoded vectors into the first layer, it actually does the same thing as the embedding. Directly feeding the quantized value will lose some complexity in the input representation but may not affect the performance too much, I guess.

The truncation you see in other implementations is, I guess, meant to drop the time steps that are not present in the target sequence. Taking your code as an example, your truncation should be:

skip = skip[:,:,-1:] 

so the output shape is (batch, 256, 1), just like your target.

Also, your current training method is quite inefficient. WaveNet is a fully convolutional model, and a fully convolutional model can be parallelized during training. My suggestion is to try predicting a whole sequence of samples as output instead of just one time sample. For example, if I want to use 5000-sample time sequences named x during training, my input and target output will be:

input = x[:, :, :4999]
target = x[:, :, 1:]          #right shift one sample

The truncation step can be removed because now the model output and the target have the same shape.
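Here is a minimal, self-contained sketch of that parallel setup, assuming the quantized waveform is kept as integer class indices of shape (batch, seq_len); a random tensor stands in for the network output, and the shapes and names are assumptions:

import torch
import torch.nn as nn

batch, n_classes, seq_len = 4, 256, 5000
samples = torch.randint(0, n_classes, (batch, seq_len))   # quantized waveform as class indices

inp = samples[:, :-1]                       # (batch, 4999): all samples except the last
target = samples[:, 1:]                     # (batch, 4999): shifted right by one sample

# The network would map `inp` to logits of shape (batch, 256, 4999); a random tensor stands
# in for that here. CrossEntropyLoss accepts (N, C, L) logits with (N, L) integer targets,
# so every time step contributes to the loss in a single forward pass.
logits = torch.randn(batch, n_classes, seq_len - 1)
loss = nn.CrossEntropyLoss()(logits, target)
print(loss.item())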

Hope these will help you.

I now understand how the embedding layers function in your implementation, but in what sense would feeding the quantized values directly “lose some complexity in the input representation”? What sort of considerations for model performance led you to that conclusion?

Also, in the more efficient training method you proposed, are the dimensions you are assuming for x (batch, 256, seq_len)? Is it normal to create a one-hot encoding by having the dimensions be (num_classes, 1) for each encoded vector? I ask because in PyTorch I thought this would mean that we have 256 channels, each with a 0/1 value indicating the class - or, in the case of the output, the probability that the next item belongs to each class.

The reason I ask the above question is that I am currently using CrossEntropyLoss which was described in the Wavenet paper. According to the PyTorch documentation, this expects input with dimensions (minibatch, num_classes) and a target describing a class index (in my case: [0,255]). Given this, I would think I should make the one-hot-encoded vectors such that the input has dimensions (batch, seq_len, 256) and the output should be (batch, 1, 256) or (batch, 256).

How would using the above dimensions affect the model, given that it is using convolutional layers? Would they still work? Am I misunderstanding CrossEntropyLoss?

Thank you for the help! I’ve been able to make more progress in the last week conceptually and in code than I have over the past few months with your help. I really appreciate it!

I now understand how the embedding layers function in your implementation, but in what sense would feeding the quantized values directly “lose some complexity in the input representation”? What sort of considerations for model performance led you to that conclusion?

Let me give you a simple example.
If we feed the value directly, the input channel size is 1 and the first layer maps it to the residual channel size, so the parameter size of the first layer is (res_channels, 1, *); it’s really just one vector. If we use one-hot encoding (or an embedding layer), the parameter size of the first layer is (res_channels, 256, *).
You can see that the first method just scales the amplitude of a single hidden vector based on the input, whereas in the second method the model learns 256 different hidden vectors, one for each quantized input value, which gives it much more freedom.
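A quick way to see the difference in parameter shapes (a sketch; res_channels is an arbitrary choice):

import torch.nn as nn

res_channels = 128

scalar_in = nn.Conv1d(1, res_channels, kernel_size=1)     # weight (128, 1, 1): one learned vector, scaled by the input value
onehot_in = nn.Conv1d(256, res_channels, kernel_size=1)   # weight (128, 256, 1): one learned vector per quantized class
embed_in = nn.Embedding(256, res_channels)                # weight (256, 128): the equivalent lookup table

print(scalar_in.weight.shape, onehot_in.weight.shape, embed_in.weight.shape)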

Also, I want to point out why the first method might still work in this case: our input data is a waveform, and even though it is mu-law quantized, its raw value still carries some information; but if our input were some kind of discrete data, like discrete tokens in language processing, the category numbers could be permuted and would carry no extra information like amplitude.
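For reference, a minimal sketch of the mu-law companding and quantization described in the WaveNet paper (my own illustration, not the poster’s encode() function), assuming the waveform has already been normalized to [-1, 1]:

import torch

def mu_law_encode(audio: torch.Tensor, mu: int = 255) -> torch.Tensor:
    # compand to [-1, 1] with f(x) = sign(x) * ln(1 + mu|x|) / ln(1 + mu), then quantize to integer classes in [0, mu]
    mu_t = torch.tensor(float(mu))
    companded = torch.sign(audio) * torch.log1p(mu_t * audio.abs()) / torch.log1p(mu_t)
    return ((companded + 1.0) / 2.0 * mu + 0.5).long()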

Also, in the more efficient training method you proposed, are the dimensions you are assuming for x (batch, 256, seq_len)? Is it normal to create a one-hot encoding by having the dimensions be (num_classes, 1) for each encoded vector? I ask because in PyTorch I thought this would mean that we have 256 channels, each with a 0/1 value indicating the class - or, in the case of the output, the probability that the next item belongs to each class.

Exactly, or you can just feed (batch, seq_len) with integer type and map it to size (batch, seq_len, 256) using an embedding layer.

The reason I ask the above question is that I am currently using CrossEntropyLoss which was described in the Wavenet paper. According to the PyTorch documentation, this expects input with dimensions (minibatch, num_classes) and a target describing a class index (in my case: [0,255]). Given this, I would think I should make the one-hot-encoded vectors such that the input has dimensions (batch, seq_len, 256) and the output should be (batch, 1, 256) or (batch, 256).

Yes, or if you use the parallel training method I suggested, your output can also be (batch, seq_len, 256) with the target as (batch, seq_len).
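One practical detail worth keeping in mind (my own addition, not something from the reply above): nn.CrossEntropyLoss expects the class dimension right after the batch dimension, so logits laid out as (batch, seq_len, 256) need a transpose before the loss. A small sketch:

import torch
import torch.nn as nn

batch, seq_len, n_classes = 4, 1024, 256
logits = torch.randn(batch, seq_len, n_classes)           # classes on the last axis
target = torch.randint(0, n_classes, (batch, seq_len))    # integer class indices

criterion = nn.CrossEntropyLoss()
loss = criterion(logits.transpose(1, 2), target)          # (batch, 256, seq_len) vs (batch, seq_len)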

I see, that explanation about the encoding makes a lot of sense. Are there any resources describing the parallelization of CNNs in PyTorch? I couldn’t find anything in the documentation other than resources talking about general parallelization using some specific PyTorch modules.

Thanks again for the help, my model loss is decreasing steadily now - just have to beat some loss plateaus. I also now understand the implementation of the architecture much better!


I’m running into another problem with my model during training and generation now, and I was wondering whether you could provide any insight on that:

#model https://github.com/Dankrushen/Wavenet-PyTorch/blob/master/wavenet/models.py
#https://github.com/ryujaehun/wavenet/blob/master/wavenet/networks.py
#https://medium.com/@satyam.kumar.iiitv/understanding-wavenet-architecture-361cc4c2d623
#https://discuss.pytorch.org/t/causal-convolution/3456/4

import os

import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
from functools import reduce

#causal convolution (citation above)
class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super(CausalConv1d, self).__init__()
        self.pad = (kernel_size - 1) * dilation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size , dilation=dilation, **kwargs)
        
    def forward(self, x):
        #pad here to only add to the left side
        x = F.pad(x, (self.pad, 0))
        return self.conv(x)

class ResidualBlock(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, skip_channels, dilation=1):
        super(ResidualBlock, self).__init__()
        self.dilation = dilation
        self.conv_sig = CausalConv1d(input_channels, output_channels, kernel_size, dilation)#dim
        self.sig = nn.Sigmoid()
        self.conv_tan = CausalConv1d(input_channels, output_channels, kernel_size, dilation)#dim
        self.tanh = nn.Tanh()
        
        #separate weights for residual and skip channels
        self.conv_r = nn.Conv1d(output_channels, output_channels, 1)#dim -> k = 1
        self.conv_s = nn.Conv1d(output_channels, skip_channels, 1)
        
    def forward(self, x):
        o = self.sig(self.conv_sig(x)) * self.tanh(self.conv_tan(x))
        skip = self.conv_s(o)
        #print("SKIP: " + str(skip.shape))
        #skip = skip[:,-self.skip_size:,:] #dim control for adding skips
        #skip = skip[:,:,-1:]
        #print("SK: " + str(skip.shape))
        residual = self.conv_r(o)
        #print("RES: " + str(residual.shape))
        
        return residual, skip
   
    

class WaveNet(nn.Module):
    def __init__(self, skip_channels=256, num_blocks=3, num_layers=10, num_hidden=256, kernel_size=2): 
        super(WaveNet, self).__init__()

        self.embed = nn.Embedding(skip_channels, skip_channels)
        self.layer1 = CausalConv1d(skip_channels, num_hidden, kernel_size)
        self.res_stack = nn.ModuleList()

        for b in range(num_blocks):
            for i in range(num_layers):
                self.res_stack.append(ResidualBlock(num_hidden, num_hidden, kernel_size, skip_channels=skip_channels, dilation=2**i))
        
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv1d(skip_channels, skip_channels, 1)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv1d(skip_channels, skip_channels, 1)
        #self.output = nn.Softmax()
        
    def forward(self, x):

        o = self.embed(x)
        print('EX: ' + str(o.size()))
        dims = o.size()
        o = o.reshape(dims[0], dims[2], dims[3])
        o = o.transpose(1,2)
        print('ETX: ' + str(o.size()))

        skip_vals = []
        #initial causal conv
        o = self.layer1(o)
        
        #run res blocks
        for i, layer in enumerate(self.res_stack):
            o, s = layer(o)
            skip_vals.append(s)
            
        #sum skip values and pass to last portion of network
        o = reduce((lambda a,b: a+b), skip_vals)
        o = self.relu1(o)
        o = self.conv1(o)
        o = self.relu2(o)
        o = self.conv2(o)
        
        return o #self.output(o)
    

data_path = '../input/musicwav/80643 Delon  Dalcan - Panik (Original Mix).wav'
model_path = 'model_3B_6L.pt'

#num iterations after which we should save the model
R = 500

#overfit model to test if it will train
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
net = WaveNet(num_blocks=3, num_layers=6)

#send to gpu
net.to(device)

criterion = nn.CrossEntropyLoss() #preproc data, stream, train, remember to reformat y label as onehot vector for classification
optimizer = optim.Adam(net.parameters(),lr=0.001)
num_epochs = 1#00
losses = []

total_steps = 0

if os.path.exists(model_path):
    state = load_model(model_path)
    net.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    total_steps = state['total_steps']

_, inp = wav_to_data(data_path)
data = encode(inp)
print(data.shape)

batch_size = 64
seq_len = 1024
net.train()

from math import floor
n_steps = floor(len(data) / batch_size)#seq_len)

i = 0
for epoch in range(num_epochs):
    i += 1
    t = 0
    for s in range(0, len(data) - seq_len - batch_size, batch_size):
        t += 1
        #(x, y) = create_singular_input_stream(data, s, batch_size)
        x = []
        y = []
        for g in range(batch_size):
            x.append(data[s+g:s+g+seq_len])
            y.append(data[s+g+1:s+g+seq_len+1])

        #print("X: " + str(torch.Tensor(x).shape))
        optimizer.zero_grad()

        output = net(torch.reshape(torch.LongTensor(x).to(device), (batch_size,1,seq_len)))
        #print(output.shape)
        #print(torch.Tensor(y).shape)

        loss = criterion(output, torch.Tensor(y).type(torch.LongTensor).to(device))

        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        #save model for future training after set amount of iterations
        if t % R == 0:
            save_model(net, optimizer, total_steps+t, model_path)

        print('Epoch {}/{}, Timestep: {}/{}, Loss: {:.6f}'.format(i, num_epochs, t, n_steps, loss.item()))
        



import torch.nn.functional as F

#generate unguided audio
#take first sequence as input and continuously generate a sequence using it
model_path = 'model_3B_6L.pt'
seq_len = 1024
num_samples = 1323000 - seq_len

out_path = 'gen.wav'

#grab first sequence to start with
curr_seq = data[:1024]
generated_seq = curr_seq

net = WaveNet(num_blocks=3, num_layers=6)
state = load_model(model_path)
net.load_state_dict(state['state_dict'])

#send to gpu, matching the .to(device) on the inputs below
net.to(device)
net.eval()

for i in range(num_samples):
    print('Sample: ' + str(i+1) + ' / ' + str(num_samples))
    output = net(torch.reshape(torch.LongTensor(curr_seq).to(device), (1,1,seq_len)))
    timestep_next = output[0,:,-1:]
    timestep_next = torch.squeeze(timestep_next)
    timestep_next = F.softmax(timestep_next, dim=0)
    print(timestep_next.size())
    print(timestep_next.cpu().detach())
    t_next = np.argmax(timestep_next.cpu().detach().numpy(), axis=0)
    print(t_next)
    
    #append to sequence
    generated_seq = generated_seq + [t_next]
    
    
#write out file
timeseries_to_wav(np.array(generated_seq), out_path)


When I’m training the model, the loss fluctuates a good deal (though this could be because of the small batch size of 64 I’m using due to memory constraints) but generally does trend downwards. At inference time, however, I find that the model produces a constant output of 0. I’ve read around that this could be caused by incorrect use of the quantized inputs; I’ve used the embedding layer rather than passing the input straight to the first convolutional layer as you suggested, but I’m unsure whether I’m making a mistake somewhere. One candidate could be the reshaping I am doing in forward() to correct the dimensions of the tensor. Another thing I noticed was that you used a Tanh activation after your embedding layer, which I did not understand - is there a purpose to using an activation function on an embedding vector?

You are feeding the same input curr_seq in each step.
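For illustration, a minimal sketch of a sampling loop that actually advances the window each step (it reuses the names from the posted code: net, data, seq_len, num_samples, device; treat it as a sketch rather than a drop-in fix):

import torch
import torch.nn.functional as F

curr_seq = [int(v) for v in data[:seq_len]]      # seed window of quantized class indices
generated_seq = list(curr_seq)

with torch.no_grad():
    for i in range(num_samples):
        x = torch.LongTensor(curr_seq).reshape(1, 1, seq_len).to(device)
        output = net(x)                                    # (1, 256, seq_len)
        probs = F.softmax(output[0, :, -1], dim=0)         # distribution over the next sample
        t_next = int(torch.argmax(probs))                  # or torch.multinomial(probs, 1) to sample
        generated_seq.append(t_next)
        curr_seq = curr_seq[1:] + [t_next]                 # slide the window forward by one sample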

The Tanh() function I added is a design choice adopted from Deep Voice 3.
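In code terms, that design choice could look roughly like this (a sketch, not the exact implementation):

import torch.nn as nn

res_channels = 256
embed = nn.Sequential(
    nn.Embedding(256, res_channels),  # map each quantized class to a latent vector
    nn.Tanh(),                        # squash the embedding, as in Deep Voice 3
)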

Wow, I was thinking about way too complicated a set of problems. Thank you again for all the help; the model is functioning much better now - hopefully it will work with enough training!