Hello,
Can you recommend an idea of simple implementation of Causal Convolution 1D (aka masked convolution) used by WaveNet?
Thank you.
Hi,
when I tried to do ByteNet in torch, I used the following residual unit. The conv_dilated has padding applied to the input, and the output is trimmed to make it causal.
I must admit that I'm not 100% sure whether it works, because the larger model did not quite work as expected and I did not get around to checking whether it was the architecture itself or the training that was at fault. My experiments with the conv layer alone did seem to show it working.
As such, I hope the below is useful, and I would be grateful for feedback on whether you think it works as expected.
An alternative to "overpadding and cutting off" is "padding before the convolution" (using torch.nn.functional.pad), but I don't know which is better.
Best regards
Thomas
class ResUnit(nn.Module):
    def __init__(self, in_channels, size=3, dilation=1, causal=False, in_ln=True):
        super(ResUnit, self).__init__()
        self.size = size
        self.dilation = dilation
        self.causal = causal
        self.in_ln = in_ln
        if self.in_ln:
            self.ln1 = nn.InstanceNorm1d(in_channels, affine=True)
            self.ln1.weight.data.fill_(1.0)
        self.conv_in = nn.Conv1d(in_channels, in_channels // 2, 1)
        self.ln2 = nn.InstanceNorm1d(in_channels // 2, affine=True)
        self.ln2.weight.data.fill_(1.0)
        # causal: pad dilation*(size-1) on both sides and trim the tail in forward();
        # non-causal: symmetric "same" padding
        self.conv_dilated = nn.Conv1d(in_channels // 2, in_channels // 2, size, dilation=self.dilation,
                                      padding=((dilation * (size - 1)) if causal else (dilation * (size - 1) // 2)))
        self.ln3 = nn.InstanceNorm1d(in_channels // 2, affine=True)
        self.ln3.weight.data.fill_(1.0)
        self.conv_out = nn.Conv1d(in_channels // 2, in_channels, 1)

    def forward(self, inp):
        x = inp
        if self.in_ln:
            x = self.ln1(x)
        x = nn.functional.leaky_relu(x)
        x = nn.functional.leaky_relu(self.ln2(self.conv_in(x)))
        x = self.conv_dilated(x)
        if self.causal and self.size > 1:
            x = x[:, :, :-self.dilation * (self.size - 1)]  # drop the trailing padding to keep causality
        x = nn.functional.leaky_relu(self.ln3(x))
        x = self.conv_out(x)
        return x + inp
I’d be interested in this as well. I’ve been trying to implement Wavenet and I ended up doing a hacky 1d padding before the convolution. Seems to work, but I am not confident this is the best solution. I created a new padding layer for 3d tensors to do this. You can see my implementation here.
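For reference, here is a minimal sketch of that "pad before the convolution" idea using torch.nn.functional.pad with left padding only; the module name CausalPadConv1d is just illustrative, not from any post above:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalPadConv1d(nn.Module):
    """Left-pads the input so the convolution never sees future samples."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        self.left_pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)

    def forward(self, x):
        # pad only the left side of the time axis: (left, right)
        x = F.pad(x, (self.left_pad, 0))
        return self.conv(x)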
Can't you just set the padding in the Conv1d to ensure the convolution is causal? This is probably more efficient than explicitly padding the input:
def CausalConv1d(in_channels, out_channels, kernel_size, dilation=1, **kwargs):
    pad = (kernel_size - 1) * dilation
    return nn.Conv1d(in_channels, out_channels, kernel_size, padding=pad, dilation=dilation, **kwargs)

...

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = CausalConv1d(256, 256, kernel_size=3, dilation=2)

    def forward(self, x):
        ...
        x = self.conv1(x)
        x = x[:, :, :-self.conv1.padding[0]]  # remove trailing padding
        ...
        return x
If you really want, you could subclass Conv1d to remove the padding in the forward() call.
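For illustration, a minimal sketch of such a subclass (not from the original post):

class CausalConv1d(nn.Conv1d):
    """Conv1d that pads symmetrically and trims the trailing padding in forward()."""
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size,
                         padding=(kernel_size - 1) * dilation, dilation=dilation, **kwargs)

    def forward(self, x):
        out = super().forward(x)
        if self.padding[0] != 0:
            out = out[:, :, :-self.padding[0]]  # drop outputs that would look into the future
        return out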
You can check that the convolution is causal with:
>>> m = CausalConv1d(1, 1, kernel_size=3, dilation=2, bias=False)
>>> x = torch.autograd.Variable(torch.randn(1, 1, 9))
>>> x[:, :, :3].data.zero_() # make the first three elements zero
>>> print(x)
>>> print(m(x)) # first three outputs should be zero
I think you are right. I'll try implementing this and see whether it breaks anything.
@colesbury et al.: shouldn't it be this?
pad = (kernel_size - 1) * dilation + 1
For example, consider kernel size 1 and dilation 1, or kernel size 2 and dilation 1. In these cases the padding should be 1 and 2, but with the equation you provided it is 0 and 1.
It seems to me that kernel size 1 needs padding 0, because at time t the conv sees [x_t] and outputs [y_t]; no padding is needed.
With kernel size 2, dilation 1, and left padding 1, the conv at time t sees [x_{t-1}, x_t] and outputs [y_t], so padding 1 is sufficient.
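A quick sanity check of that arithmetic, assuming the CausalConv1d function from the earlier post (the one that returns a plain nn.Conv1d, so the trailing pad still has to be trimmed by hand):

>>> m = CausalConv1d(1, 1, kernel_size=2, dilation=1, bias=False)
>>> x = torch.randn(1, 1, 8)
>>> m(x)[:, :, :-m.padding[0]].shape    # trimmed output has the same length as the input
torch.Size([1, 1, 8])
>>> x2 = x.clone(); x2[:, :, 5:] = 0.0  # change only "future" samples
>>> torch.allclose(m(x)[:, :, :5], m(x2)[:, :, :5])
True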
Probably what I understand by causal is wrong.
I thought causal means not using the future nor the present, i.e. p(Y_i | X_{j<i}), instead of what you say, i.e. p(Y_i | X_{j<=i}).
Generally Y_i is the target for X_{j<=i}.
If I follow your logic, then Y_0 is the target for X_{j<0}, and you have one element in your target array that doesn't correspond to any input.
Even for time series prediction, the convention is to set Y_i = X_{i-1}.
Yeah, one common solution when predicting the first element is to pad the input with zeros and use the padded input to predict Y_0.
Sergei’s post explains it in the context of Wavenet. The first convolution is padded such that the model doesn’t use the current sample to predict the current sample.
After the first convolution, we then have a structure that seems to be using P(Y_i | X_{j<=i}) when in fact it’s really P(Y_i | E_i, X_{i-1}), that is, the probability of the current sample given the embeddings of the current sample and the previous sample used to create the embedding.
I’m abusing notation.
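If it helps, here is a minimal sketch of that zero-padding/shifting convention; the function name and shapes are my own, purely illustrative:

import torch
import torch.nn.functional as F

def make_autoregressive_pair(x):
    """Given a batch of waveforms x of shape (N, C, L), return (inputs, targets)
    so that position t of the inputs only contains information from before t."""
    # shift right by one step: pad one zero on the left of the time axis, drop the last step
    inputs = F.pad(x, (1, 0))[:, :, :-1]
    targets = x
    return inputs, targets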
There’s a good WaveNet implementation in PyTorch from Nov 2019 in the Seq-U-Net repo. It includes Dilated Causal Convolutions. Source: Seq-U-Net/wavenet_model.py at master · f90/Seq-U-Net · GitHub
from torch import nn
import torch


class WaveNetModel(nn.Module):
    """
    A Complete Wavenet Model

    Args:
        layers (Int):            Number of layers in each block
        blocks (Int):            Number of wavenet blocks of this model
        dilation_channels (Int): Number of channels for the dilated convolution
        residual_channels (Int): Number of channels for the residual connection
        skip_channels (Int):     Number of channels for the skip connections
        classes (Int):           Number of possible values each sample can have
        output_length (Int):     Number of samples that are generated for each input
        kernel_size (Int):       Size of the dilation kernel
        dtype:                   Parameter type of this model

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})`
        - Output: :math:`()`
          L should be the length of the receptive field
    """

    def __init__(self,
                 layers=10,
                 blocks=5,
                 dilation_channels=32,
                 residual_channels=32,
                 skip_channels=512,
                 classes=256,
                 output_length=32,
                 kernel_size=2,
                 dtype=torch.FloatTensor,
                 bias=False,
                 fast=False):

        super(WaveNetModel, self).__init__()

        self.layers = layers
        self.blocks = blocks
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.classes = classes
        self.kernel_size = kernel_size
        self.dtype = dtype
        self.fast = fast

        # build model
        receptive_field = 1
        init_dilation = 1

        self.dilations = []
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        # 1x1 convolution to create channels
        self.start_conv = nn.Conv1d(in_channels=self.classes,
                                    out_channels=residual_channels,
                                    kernel_size=1,
                                    bias=bias)

        for b in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for i in range(layers):
                # dilations of this layer
                self.dilations.append((new_dilation, init_dilation))

                # dilated convolutions
                self.filter_convs.append(nn.Conv1d(in_channels=residual_channels,
                                                   out_channels=dilation_channels,
                                                   kernel_size=kernel_size,
                                                   bias=bias,
                                                   dilation=new_dilation))

                self.gate_convs.append(nn.Conv1d(in_channels=residual_channels,
                                                 out_channels=dilation_channels,
                                                 kernel_size=kernel_size,
                                                 bias=bias,
                                                 dilation=new_dilation))

                # 1x1 convolution for residual connection
                self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                     out_channels=residual_channels,
                                                     kernel_size=1,
                                                     bias=bias))

                # 1x1 convolution for skip connection
                self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                 out_channels=skip_channels,
                                                 kernel_size=1,
                                                 bias=bias))

                receptive_field += additional_scope
                additional_scope *= 2
                init_dilation = new_dilation
                new_dilation *= 2

        self.end_conv_1 = nn.Conv1d(in_channels=skip_channels,
                                    out_channels=skip_channels,
                                    kernel_size=1,
                                    bias=True)

        self.end_conv_2 = nn.Conv1d(in_channels=skip_channels,
                                    out_channels=classes,
                                    kernel_size=1,
                                    bias=True)

        # self.output_length = 2 ** (layers - 1)
        self.output_size = output_length
        self.receptive_field = receptive_field
        self.input_size = receptive_field + output_length - 1

    def forward(self, input, mode="normal"):
        if mode == "save":
            self.inputs = [None] * (self.blocks * self.layers)

        x = self.start_conv(input)
        skip = 0

        # WaveNet layers
        for i in range(self.blocks * self.layers):

            #            |----------------------------------------|     *residual*
            #            |                                        |
            #            |    |-- conv -- tanh --|                |
            # -> dilate -|----|                  * ----|-- 1x1 -- + -->  *input*
            #                 |-- conv -- sigm --|     |
            #                                         1x1
            #                                          |
            # ---------------------------------------> + ------------->  *skip*

            (dilation, init_dilation) = self.dilations[i]

            if mode == "save":
                self.inputs[i] = x[:, :, -(dilation * (self.kernel_size - 1) + 1):]
            elif mode == "step":
                self.inputs[i] = torch.cat([self.inputs[i][:, :, 1:], x], dim=2)
                x = self.inputs[i]

            # dilated convolution
            residual = x

            filter = self.filter_convs[i](x)
            filter = torch.tanh(filter)
            gate = self.gate_convs[i](x)
            gate = torch.sigmoid(gate)
            x = filter * gate

            # parametrized skip connection
            s = self.skip_convs[i](x)
            if not isinstance(skip, int):  # the original `skip is not 0` warns on newer Python
                skip = skip[:, :, -s.size(2):]
            skip = s + skip

            x = self.residual_convs[i](x)
            x = x + residual[:, :, dilation * (self.kernel_size - 1):]

        x = torch.relu(skip)
        x = torch.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)

        return x
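For what it's worth, a quick usage sketch of the class as quoted; the one-hot input and sizes are my own illustrative assumptions, not part of the original post:

model = WaveNetModel(layers=10, blocks=5, classes=256, output_length=32)
# expected input: one-hot encoded samples, shape (batch, classes, receptive_field + output_length - 1)
x = torch.zeros(1, 256, model.input_size)
x[:, 128, :] = 1.0        # dummy constant signal, one-hot at class 128
logits = model(x)         # class scores per generated time step
print(logits.shape)       # torch.Size([1, 256, 32]) with the defaults above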
One thing to point out: add an if statement like the one below to make it robust to padding=0:

if self.causal_conv.padding[0] != 0:
    x = x[:, :, :-self.causal_conv.padding[0]]  # remove trailing padding
class CausalConv1d(nn.Module):
    """
    causal conv1d
    return the sequence with the same length after
    1D causal convolution
    """
    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super(CausalConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=self.padding, dilation=dilation)

    def forward(self, x):
        """
        shape of x: [total_seq, num_features, num_timesteps]
        """
        x = self.conv(x)
        return x[:, :, :-self.padding]
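A quick shape check of that module (the numbers are just for illustration):

>>> conv = CausalConv1d(4, 8, kernel_size=3, dilation=2)
>>> x = torch.randn(2, 4, 50)
>>> conv(x).shape    # same number of timesteps as the input
torch.Size([2, 8, 50])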
Just a late follow-up question on this:
It seems to me that symmetrically padding and then cutting away the trailing padding is a convenient workaround for avoiding asymmetric padding. Wouldn't it be better (and by that I mean computationally faster) to apply asymmetric padding to the start of the sequence using F.pad before the convolution? Then the kernel outputs for the right-hand padding, which are cut away anyway, would not have to be computed. This would save some unnecessary computation, no?
E.g. if we apply a kernel k=2 with dilation d=1, we would have p = (k-1)*d = 1
The sequence
ABCD
would then be padded like this:
PABCDP
This would result in kernel output:
PA,AB,BC,CD,DP
and we would then cut away [-p:] = [-1:], thus removing the element DP, which would have been computed unnecessarily, and end up with the kernel outputs to add up:
PA,AB,BC,CD
I guess the question is which one is faster: padding asymmetrically in advance with F.pad, or padding symmetrically in Conv1d() and then removing [-p:]. Has anybody tried that? Or is my understanding of the causal padding trick wrong?
Thanks! Best, JZ
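In case it helps, a rough sketch for checking that the two variants agree and timing them; copying the weights makes the comparison fair, and the actual numbers will depend on hardware (they are not from this thread):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import benchmark

k, d, C, L = 2, 1, 64, 16000
pad = (k - 1) * d

sym = nn.Conv1d(C, C, k, dilation=d, padding=pad)   # symmetric pad, trim afterwards
asym = nn.Conv1d(C, C, k, dilation=d)               # no pad, F.pad the input instead
asym.load_state_dict(sym.state_dict())              # identical weights for a fair comparison

x = torch.randn(8, C, L)
y_sym = sym(x)[:, :, :-pad]
y_asym = asym(F.pad(x, (pad, 0)))
print(torch.allclose(y_sym, y_asym, atol=1e-6))     # both compute the same causal convolution

for label, stmt in [("sym+trim", "sym(x)[:, :, :-pad]"),
                    ("F.pad+conv", "asym(F.pad(x, (pad, 0)))")]:
    t = benchmark.Timer(stmt=stmt, globals=globals(), label=label)
    print(t.timeit(100))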