PyTorch Wavenet Model loss is not decreasing (Help)

@Arham_Khan

Just some mistakes I’ve noticed:

  1. The padding in CausalConv1d is not causal. Passing padding=(kernel_size - 1) * dilation to Conv1d pads both sides of the input’s last dimension, so each output can see future samples. The correct way is to pad the input on the left in the forward call and set padding=0 in Conv1d, as in the sketch below:
```python
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation  # left padding that keeps the conv causal
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=0, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, (self.pad, 0))  # only pad on the left side
        return self.conv(x)
```
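
A quick sanity check (hypothetical shapes and values) that the fix behaves causally: the output length matches the input length, and perturbing a future sample leaves earlier outputs unchanged.

```python
import torch

conv = CausalConv1d(1, 1, kernel_size=2, dilation=4)
x = torch.randn(1, 1, 16)
x2 = x.clone()
x2[..., -1] += 1.0                 # perturb only the last time step
y1, y2 = conv(x), conv(x2)
assert y1.shape[-1] == x.shape[-1]                 # length preserved
assert torch.allclose(y1[..., :-1], y2[..., :-1])  # past outputs unaffected by a future change
```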
  2. The skip_size param doesn’t make any sense to me, especially in this line:
```python
skip = skip[:,:,-self.skip_size:] #dim control for adding skips
```

You set skip_size to 256 to match the one-hot vector of y, which represents the value of the sample immediately following the 1024 input samples. But you take 256 values from the last axis, which represents time, so you’re effectively mapping 256 different time positions to 256 different amplitude values, which looks weird to me :thinking:.
Those 256 values should be taken from the second axis, which represents the hidden channels. A sketch of one way to do this follows.
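
A minimal sketch of the usual WaveNet output head (assuming skip has shape (batch, hidden_channels, time), and a hypothetical end_conv layer built as nn.Conv1d(hidden_channels, 256, kernel_size=1)): collapse the channel axis to 256 classes with a 1x1 convolution, then read the prediction off the last time step instead of slicing the time axis.

```python
out = F.relu(skip)        # skip: (batch, hidden_channels, time)
out = self.end_conv(out)  # 1x1 conv -> (batch, 256, time)
logits = out[:, :, -1]    # (batch, 256): one logit per amplitude class for the next sample
```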

  3. You directly cast x to float type, which makes the quantization completely useless. In my opinion, the shape of x should be (batch, 256, 1024) with one-hot vectors along the second axis, and the shape of y should be (batch, 256, 1) with a one-hot vector along the second axis (see the sketch below).
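
A minimal sketch of that encoding (hypothetical tensor names; assuming the audio has already been mu-law quantized to 256 integer classes):

```python
import torch
import torch.nn.functional as F

samples = torch.randint(0, 256, (8, 1025))         # (batch, 1025) quantized samples in [0, 255]
x = F.one_hot(samples[:, :1024], num_classes=256)  # (batch, 1024, 256)
x = x.permute(0, 2, 1).float()                     # (batch, 256, 1024), one-hot along axis 1
y = F.one_hot(samples[:, 1024:], num_classes=256)  # (batch, 1, 256)
y = y.permute(0, 2, 1).float()                     # (batch, 256, 1), one-hot along axis 1
```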

I think you might have some misunderstandings of WaveNet, but this is normal, because the original paper doesn’t explain everything very clearly (to me), so I recommend taking a look at others’ implementations to get a wider view of WaveNet.

The WaveNet model I have implemented: