Causal Convolution

Hello,

Can you recommend a simple way to implement a 1D causal convolution (a.k.a. masked convolution), as used by WaveNet?

Thank you.


Hi,

When I tried to implement ByteNet in PyTorch, I used the following residual unit. conv_dilated is padded on both ends, and the trailing outputs are cut off to make it causal.
I must admit that I’m not 100% sure it works: the full model does not quite work as expected, and I have not yet checked whether the architecture itself or the training is at fault. My experiments with the conv layer alone suggest that it works.

As such, I hope the code below is useful, and I would be grateful for feedback on whether you think it works as expected.
An alternative to “overpadding and cutting off” is “padding before the convolution” (using torch.nn.functional.pad), but I don’t know which is better.
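A minimal sketch of the “padding before the convolution” variant, assuming a standard (N, C, L) input; the class name LeftPadCausalConv1d is hypothetical. It left-pads by dilation*(kernel_size-1) with torch.nn.functional.pad and runs an unpadded nn.Conv1d, so there are no trailing outputs to cut off.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeftPadCausalConv1d(nn.Module):
    """Hypothetical causal Conv1d that pads only the left (past) side."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        # Number of past samples the dilated kernel reaches back over.
        self.left_pad = dilation * (kernel_size - 1)
        # padding=0: all padding is applied manually in forward().
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              dilation=dilation, padding=0)

    def forward(self, x):
        # F.pad pads the last dimension by (left, right) amounts.
        x = F.pad(x, (self.left_pad, 0))
        return self.conv(x)


torch.manual_seed(0)
layer = LeftPadCausalConv1d(4, 8, kernel_size=3, dilation=2)
x = torch.randn(2, 4, 16)
y = layer(x)
print(y.shape)  # torch.Size([2, 8, 16])
```

The output keeps the input length, and changing inputs after position t leaves outputs up to t untouched.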

Best regards

Thomas

import torch.nn as nn

class ResUnit(nn.Module):
    def __init__(self, in_channels, size=3, dilation=1, causal=False, in_ln=True):
        super(ResUnit, self).__init__()
        self.size = size
        self.dilation = dilation
        self.causal = causal
        self.in_ln = in_ln
        if self.in_ln:
            self.ln1 = nn.InstanceNorm1d(in_channels, affine=True)
            self.ln1.weight.data.fill_(1.0)
        self.conv_in = nn.Conv1d(in_channels, in_channels//2, 1)
        self.ln2 = nn.InstanceNorm1d(in_channels//2, affine=True)
        self.ln2.weight.data.fill_(1.0)
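        # Causal: pad both ends by the kernel's full reach, dilation*(size-1);
        # forward() then cuts off the trailing outputs. Non-causal: centered padding.
        self.conv_dilated = nn.Conv1d(in_channels//2, in_channels//2, size, dilation=self.dilation,
                                      padding=((dilation*(size-1)) if causal else (dilation*(size-1)//2)))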
        self.ln3 = nn.InstanceNorm1d(in_channels//2, affine=True)
        self.ln3.weight.data.fill_(1.0)
        self.conv_out = nn.Conv1d(in_channels//2, in_channels, 1) 

    def forward(self, inp):
        x = inp
        if self.in_ln:
            x = self.ln1(x)
        x = nn.functional.leaky_relu(x)
        x = nn.functional.leaky_relu(self.ln2(self.conv_in(x)))        
        x = self.conv_dilated(x)
        if self.causal and self.size>1:
            # drop the last dilation*(size-1) outputs, which saw right-hand (future) padding
            x = x[:,:,:-self.dilation*(self.size-1)]
        x = nn.functional.leaky_relu(self.ln3(x))
        x = self.conv_out(x)
        return x+inp
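One thing that may be worth checking in the unit above, offered as an observation rather than a diagnosis: nn.InstanceNorm1d (with its default track_running_stats=False) normalizes each channel using the mean and variance over the whole time axis, so every output position can depend on future inputs even when the dilated convolution itself is causal. A small sketch that changes only the last time step and shows that the earlier normalized outputs move:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
norm = nn.InstanceNorm1d(4, affine=True)  # same layer type as ln1/ln2/ln3 above

x = torch.randn(1, 4, 10)
x_future = x.clone()
x_future[:, :, -1] += 5.0  # change only the last (future) time step

with torch.no_grad():
    y = norm(x)
    y_future = norm(x_future)

# The first 9 input positions are identical, but their normalized values
# differ, because the per-channel statistics include the last step.
print(torch.allclose(y[:, :, :-1], y_future[:, :, :-1]))  # False
```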

I’d be interested in this as well. I’ve been trying to implement WaveNet and ended up doing a hacky 1D padding before the convolution. It seems to work, but I am not confident this is the best solution. I created a new padding layer for 3D tensors to do this. You can see my implementation here.
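For comparison, PyTorch already ships a padding layer for 3D (N, C, L) tensors, nn.ConstantPad1d, which takes a (left, right) tuple. A minimal sketch of a left-padded causal block built from it; the helper name causal_conv_block is hypothetical, and this is not the implementation linked above.

```python
import torch
import torch.nn as nn


def causal_conv_block(in_channels, out_channels, kernel_size, dilation=1):
    """Hypothetical helper: left zero-padding followed by an unpadded Conv1d."""
    left_pad = dilation * (kernel_size - 1)
    return nn.Sequential(
        nn.ConstantPad1d((left_pad, 0), 0.0),  # pad only the past side
        nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation),
    )


block = causal_conv_block(1, 1, kernel_size=2, dilation=4)
x = torch.randn(3, 1, 20)
print(block(x).shape)  # torch.Size([3, 1, 20])
```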

Can’t you just set the padding in the Conv1d to ensure the convolution is causal? This is probably more efficient than explicitly padding the input:

import torch
import torch.nn as nn

def CausalConv1d(in_channels, out_channels, kernel_size, dilation=1, **kwargs):
    # Pad by the full reach of the dilated kernel. nn.Conv1d pads both ends,
    # so the trailing `pad` outputs must be cut off after the convolution.
    pad = (kernel_size - 1) * dilation
    return nn.Conv1d(in_channels, out_channels, kernel_size, padding=pad, dilation=dilation, **kwargs)

...

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = CausalConv1d(256, 256, kernel_size=3, dilation=2)

    def forward(self, x):
        ...
        x = self.conv1(x)
        x = x[:, :, :-self.conv1.padding[0]]  # remove trailing padding
        ...
        return x
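Why the trailing outputs have to go: with stride 1, nn.Conv1d produces L_out = L_in + 2*pad - dilation*(kernel_size - 1) samples, and with pad = dilation*(kernel_size - 1) that is L_in + pad. A quick check with sizes chosen arbitrarily for illustration:

```python
import torch
import torch.nn as nn

kernel_size, dilation, length = 3, 2, 9
pad = (kernel_size - 1) * dilation  # 4

conv = nn.Conv1d(1, 1, kernel_size, padding=pad, dilation=dilation)
x = torch.randn(1, 1, length)
y = conv(x)

# L_out = L_in + 2*pad - dilation*(kernel_size - 1) = L_in + pad
print(y.shape[-1])               # 13
print(y[:, :, :-pad].shape[-1])  # 9, same length as the input
```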

If you really want, you could subclass Conv1d to remove the padding in the forward() call.
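A minimal sketch of that subclassing idea, assuming stride 1; the class name CausalConv1dTrim is hypothetical. It passes the causal padding to nn.Conv1d and trims in forward(), skipping the trim when the padding is 0 (kernel_size=1).

```python
import torch
import torch.nn as nn


class CausalConv1dTrim(nn.Conv1d):
    """Hypothetical Conv1d subclass that returns only the causal outputs."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        pad = (kernel_size - 1) * dilation
        super().__init__(in_channels, out_channels, kernel_size,
                         padding=pad, dilation=dilation, **kwargs)
        self.causal_pad = pad

    def forward(self, x):
        y = super().forward(x)
        if self.causal_pad > 0:
            # Remove the trailing outputs that were computed over right-hand padding.
            y = y[:, :, :-self.causal_pad]
        return y


m = CausalConv1dTrim(2, 3, kernel_size=3, dilation=2)
x = torch.randn(1, 2, 12)
print(m(x).shape)  # torch.Size([1, 3, 12])
```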

You can check that the convolution is causal with:

>>> m = CausalConv1d(1, 1, kernel_size=3, dilation=2, bias=False)
>>> x = torch.randn(1, 1, 9)
>>> x[:, :, :3] = 0  # make the first three elements zero
>>> print(x)
>>> print(m(x))  # first three outputs should be zero
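Zeroing a prefix tests a few positions at once; a slightly more general sketch perturbs everything after each time step t and confirms that outputs up to t do not move. The helper name check_causal is hypothetical, and the "wrong end" variant is included only to show that the check catches a non-causal alignment.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def check_causal(fn, channels, length, trials=3):
    """Hypothetical helper: True if fn's output at t ignores inputs after t."""
    x = torch.randn(1, channels, length)
    with torch.no_grad():
        y = fn(x)
        for t in range(length - 1):
            for _ in range(trials):
                x2 = x.clone()
                # Replace every input after position t with fresh noise.
                x2[:, :, t + 1:] = torch.randn(1, channels, length - t - 1)
                if not torch.allclose(y[:, :, :t + 1], fn(x2)[:, :, :t + 1]):
                    return False
    return True


kernel_size, dilation = 3, 2
pad = (kernel_size - 1) * dilation
conv = nn.Conv1d(1, 1, kernel_size, padding=pad, dilation=dilation)


def trimmed(x):
    # Causal: drop the last `pad` outputs.
    return conv(x)[:, :, :-pad]


def wrong_end(x):
    # Drops the first `pad` outputs instead, so output t sees x[t+2] and x[t+4].
    return conv(x)[:, :, pad:]


print(check_causal(trimmed, 1, 9))    # True
print(check_causal(wrong_end, 1, 9))  # False
```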

I think you are right. I’ll try to implement this and see whether it breaks anything.

@colesbury et al. shouldn’t it be this?

pad = (kernel_size - 1) * dilation + 1

For example, consider kernel size 1 and dilation 1 or kernel size 2 and dilation 1. In these cases padding should be 1 and 2 but given the equation you provided they are 0 and 1.

It seems to me that kernel size 1 needs padding 0 because at time t, the conv will see [x_t] and output [y_t]. No padding needed.

With kernel size 2, dilation 1, left padding 1, the conv at time t will see [x_{t-1}, x_t] and output [y_t]. Padding 1 is sufficient.
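The kernel size 2, dilation 1 case can be checked numerically: with left padding 1, output t should equal w[0]*x[t-1] + w[1]*x[t], with x[-1] = 0 (PyTorch convolutions are cross-correlations, so w[0] multiplies the older sample). A small sketch with hand-set weights chosen arbitrarily for illustration:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # shape (N=1, C=1, L=4)
w = torch.tensor([[[10.0, 1.0]]])           # shape (out=1, in=1, k=2)

# Left-pad by (k - 1) * d = 1, then an unpadded convolution.
y = F.conv1d(F.pad(x, (1, 0)), w)

# Expected: y[t] = 10 * x[t-1] + 1 * x[t], with x[-1] = 0
print(y.flatten().tolist())  # [1.0, 12.0, 23.0, 34.0]
```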


Perhaps my understanding of causal is wrong.
I thought causal means using neither the future nor the present, i.e. p(Y_i | X_{j<i}), instead of what you describe, i.e. p(Y_i | X_{j<=i}).

Generally Y_i is the target for X_{j<=i}.

If I follow your logic, then Y_0 is the target for X_{j<0}, and you have one element in your target array that doesn’t correspond to any input.

Even for time series prediction, the convention is to set Y_i = X_{i-1}.

Yeah, one common solution when predicting the first element is to pad the input with zeros and use the padded input to predict Y_0.
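A minimal sketch of that zero-padding convention for next-sample prediction, assuming the model input at position i should contain only samples before i; the function name make_shifted_inputs is hypothetical. The sequence is shifted right by one with a leading zero, so Y_0 is predicted from the pad alone.

```python
import torch
import torch.nn.functional as F


def make_shifted_inputs(x):
    """Hypothetical helper: inputs[..., i] = x[..., i-1], inputs[..., 0] = 0."""
    return F.pad(x, (1, 0))[..., :-1]


x = torch.tensor([[[5.0, 6.0, 7.0, 8.0]]])  # (N, C, L); also serves as the targets
inputs = make_shifted_inputs(x)
print(inputs.flatten().tolist())  # [0.0, 5.0, 6.0, 7.0]
```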

Sergei’s post explains it in the context of WaveNet. The first convolution is padded such that the model doesn’t use the current sample to predict the current sample.

After the first convolution, we then have a structure that seems to be using P(Y_i | X_{j<=i}) when in fact it’s really P(Y_i | E_i, X_{i-1}), that is, the probability of the current sample given the embeddings of the current sample and the previous sample used to create the embedding.
I’m abusing notation.
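One way to build such a strictly-past first layer, offered as a sketch under stated assumptions: left-pad by (kernel_size - 1) * dilation + 1 (the formula suggested earlier in the thread) and drop the last output, so output t depends only on samples before t. The class name ShiftedCausalConv1d is hypothetical, and WaveNet implementations differ in how they realize this shift.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ShiftedCausalConv1d(nn.Module):
    """Hypothetical conv whose output at t sees only x[..., :t], not x[..., t]."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super().__init__()
        self.left_pad = (kernel_size - 1) * dilation + 1
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)

    def forward(self, x):
        y = self.conv(F.pad(x, (self.left_pad, 0)))  # length L + 1
        return y[..., :-1]                           # keep L outputs


torch.manual_seed(0)
m = ShiftedCausalConv1d(1, 1, kernel_size=2)
x = torch.randn(1, 1, 6)
x2 = x.clone()
x2[..., 3] += 1.0  # change only the sample at t = 3
with torch.no_grad():
    y, y2 = m(x), m(x2)
# Outputs 0..3 ignore x[3]; output 4 is the first to see it.
print(torch.allclose(y[..., :4], y2[..., :4]))  # True
```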


There’s a good WaveNet implementation in PyTorch from Nov 2019 in the Seq-U-Net repo. It includes Dilated Causal Convolutions. Source: Seq-U-Net/wavenet_model.py at master · f90/Seq-U-Net · GitHub

from torch import nn
import torch

class WaveNetModel(nn.Module):
    """
    A Complete Wavenet Model
    Args:
        layers (Int):               Number of layers in each block
        blocks (Int):               Number of wavenet blocks of this model
        dilation_channels (Int):    Number of channels for the dilated convolution
        residual_channels (Int):    Number of channels for the residual connection
        skip_channels (Int):        Number of channels for the skip connections
        classes (Int):              Number of possible values each sample can have
        output_length (Int):        Number of samples that are generated for each input
        kernel_size (Int):          Size of the dilation kernel
        dtype:                      Parameter type of this model
    Shape:
        - Input: :math:`(N, C_{in}, L_{in})`
        - Output: :math:`()`
        L should be the length of the receptive field
    """
    def __init__(self,
                 layers=10,
                 blocks=5,
                 dilation_channels=32,
                 residual_channels=32,
                 skip_channels=512,
                 classes=256,
                 output_length=32,
                 kernel_size=2,
                 dtype=torch.FloatTensor,
                 bias=False,
                 fast=False):

        super(WaveNetModel, self).__init__()

        self.layers = layers
        self.blocks = blocks
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.classes = classes
        self.kernel_size = kernel_size
        self.dtype = dtype
        self.fast = fast

        # build model
        receptive_field = 1
        init_dilation = 1

        self.dilations = []

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        # 1x1 convolution to create channels
        self.start_conv = nn.Conv1d(in_channels=self.classes,
                                    out_channels=residual_channels,
                                    kernel_size=1,
                                    bias=bias)

        for b in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for i in range(layers):
                # dilations of this layer
                self.dilations.append((new_dilation, init_dilation))

                # dilated convolutions
                self.filter_convs.append(nn.Conv1d(in_channels=residual_channels,
                                                   out_channels=dilation_channels,
                                                   kernel_size=kernel_size,
                                                   bias=bias,
                                                   dilation=new_dilation))

                self.gate_convs.append(nn.Conv1d(in_channels=residual_channels,
                                                 out_channels=dilation_channels,
                                                 kernel_size=kernel_size,
                                                 bias=bias,
                                                 dilation=new_dilation))

                # 1x1 convolution for residual connection
                self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                     out_channels=residual_channels,
                                                     kernel_size=1,
                                                     bias=bias))

                # 1x1 convolution for skip connection
                self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                 out_channels=skip_channels,
                                                 kernel_size=1,
                                                 bias=bias))

                receptive_field += additional_scope
                additional_scope *= 2
                init_dilation = new_dilation
                new_dilation *= 2

        self.end_conv_1 = nn.Conv1d(in_channels=skip_channels,
                                  out_channels=skip_channels,
                                  kernel_size=1,
                                  bias=True)

        self.end_conv_2 = nn.Conv1d(in_channels=skip_channels,
                                    out_channels=classes,
                                    kernel_size=1,
                                    bias=True)

        # self.output_length = 2 ** (layers - 1)
        self.output_size = output_length
        self.receptive_field = receptive_field
        self.input_size = receptive_field + output_length - 1

    def forward(self, input, mode="normal"):
        if mode == "save":
            self.inputs = [None]* (self.blocks * self.layers)

        x = self.start_conv(input)
        skip = 0

        # WaveNet layers
        for i in range(self.blocks * self.layers):

            #            |----------------------------------------|     *residual*
            #            |                                        |
            #            |    |-- conv -- tanh --|                |
            # -> dilate -|----|                  * ----|-- 1x1 -- + -->	*input*
            #                 |-- conv -- sigm --|     |
            #                                         1x1
            #                                          |
            # ---------------------------------------> + ------------->	*skip*

            (dilation, init_dilation) = self.dilations[i]

            if mode == "save":
                self.inputs[i] = x[:,:,-(dilation*(self.kernel_size-1) + 1):]
            elif mode == "step":
                self.inputs[i] = torch.cat([self.inputs[i][:,:,1:], x], dim=2)
                x = self.inputs[i]

            # dilated convolution
            residual = x

            filter = self.filter_convs[i](x)
            filter = torch.tanh(filter)
            gate = self.gate_convs[i](x)
            gate = torch.sigmoid(gate)
            x = filter * gate

            # parametrized skip connection
            s = self.skip_convs[i](x)
            if torch.is_tensor(skip):  # skip starts as the int 0; crop once it is a tensor
                skip = skip[:, :, -s.size(2):]
            skip = s + skip

            x = self.residual_convs[i](x)
            x = x + residual[:, :, dilation * (self.kernel_size - 1):]

        x = torch.relu(skip)
        x = torch.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)

        return x
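Because this model uses unpadded (valid) dilated convolutions, each layer shortens the sequence by dilation * (kernel_size - 1), and the residual and skip paths are cropped from the left to match. A small sketch that reproduces the constructor's receptive-field loop as a standalone function (the name wavenet_sizes is hypothetical) and checks it against the total shrinkage, using the constructor's default arguments:

```python
def wavenet_sizes(layers=10, blocks=5, kernel_size=2, output_length=32):
    """Hypothetical helper mirroring the receptive_field loop in WaveNetModel."""
    receptive_field = 1
    total_shrink = 0
    for _ in range(blocks):
        dilation = 1
        for _ in range(layers):
            shrink = dilation * (kernel_size - 1)  # length lost by one valid conv
            receptive_field += shrink
            total_shrink += shrink
            dilation *= 2
    input_size = receptive_field + output_length - 1
    # (receptive field, required input length, resulting output length)
    return receptive_field, input_size, input_size - total_shrink


print(wavenet_sizes())  # (5116, 5147, 32)
```

So an input of self.input_size samples yields exactly output_length outputs.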

One thing to point out: add an if statement like the one below to make it robust to padding=0 (e.g. kernel_size=1):

      if self.causal_conv.padding[0] != 0:
         x = x[:, :, :-self.causal_conv.padding[0]]  # remove trailing padding
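Why the guard matters: in Python -0 == 0, so x[:, :, :-0] is the same as x[:, :, :0], an empty slice, rather than "no trimming". A quick illustration with an arbitrary tensor:

```python
import torch

x = torch.randn(2, 3, 5)
pad = 0  # e.g. kernel_size=1

print(x[:, :, :-pad].shape)  # torch.Size([2, 3, 0]): everything removed
trimmed = x[:, :, :-pad] if pad != 0 else x
print(trimmed.shape)         # torch.Size([2, 3, 5])
```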

import torch.nn as nn

class CausalConv1d(nn.Module):
    """
    Causal Conv1d: returns a sequence of the same length
    after a 1D causal convolution.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 dilation):
        super(CausalConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = dilation*(kernel_size-1)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              padding=self.padding, dilation=dilation)

    def forward(self, x):
        """
        shape of x: [batch (number of sequences), num_features, num_timesteps]
        """
        x = self.conv(x)
        if self.padding != 0:
            x = x[:, :, :-self.padding]  # remove trailing padding
        return x

Just a late follow up question on this:

It seems to me that symmetrically padding and then cutting away the trailing padding is a convenience workaround to avoid asymmetric padding. Wouldn’t it be better (by which I mean computationally faster) to apply asymmetric padding to the start of the sequence using F.pad before the convolution? Then the kernel outputs for the right-hand padding, which are cut away anyway, would not have to be computed. This would save some unnecessary computation, no?

E.g. if we apply a kernel k=2 with dilation d=1, we would have p = (k-1)*d = 1
The sequence

ABCD

would then be padded like this:

PABCDP

This would result in kernel output:

PA,AB,BC,CD,DP

and we would then cut away [-p:] = [-1:], thus remove the element DP, which would have been unnecessarily computed and end up with the kernel output to add up:

PA,AB,BC,CD

I guess the question is which one is faster: padding asymmetrically in advance with F.pad, or padding symmetrically in Conv1d() and then removing [-p:]. Has anybody tried that? Or is my understanding of the causal padding trick wrong?
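A rough way to try it, as a sketch only: the two hypothetical helpers below share one set of weights and produce matching outputs, and torch.utils.benchmark.Timer can time them on your own hardware. No timing results are claimed here. The overpadding version computes p extra output columns, while the F.pad version materializes a padded copy of the input, so which one wins may depend on tensor sizes and backend.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

torch.manual_seed(0)
k, d = 2, 1
p = (k - 1) * d

# Share one set of weights between both variants so outputs are comparable.
conv_overpad = nn.Conv1d(16, 16, k, dilation=d, padding=p)
conv_leftpad = nn.Conv1d(16, 16, k, dilation=d, padding=0)
conv_leftpad.load_state_dict(conv_overpad.state_dict())


def overpad_then_cut(x):
    # Pads p on both sides inside Conv1d, then discards the last p outputs.
    return conv_overpad(x)[:, :, :-p]


def left_pad_first(x):
    # Pads p on the left only, so no extra outputs are computed.
    return conv_leftpad(F.pad(x, (p, 0)))


x = torch.randn(8, 16, 1024)
with torch.no_grad():
    assert torch.allclose(overpad_then_cut(x), left_pad_first(x), atol=1e-5)
    for fn in (overpad_then_cut, left_pad_first):
        timer = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
        print(fn.__name__, timer.timeit(100))
```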

Thanks! Best, JZ