Causal Convolution

Can’t you just set the padding in the Conv1d to ensure the convolution is causal? This is probably more efficient than explicitly padding the input:

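import torch.nn as nn

def CausalConv1d(in_channels, out_channels, kernel_size, dilation=1, **kwargs):
    # Conv1d pads BOTH ends by this amount; the extra outputs at the end are trimmed in forward().
    pad = (kernel_size - 1) * dilation
    return nn.Conv1d(in_channels, out_channels, kernel_size, padding=pad, dilation=dilation, **kwargs)
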
...

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = CausalConv1d(256, 256, kernel_size=3, dilation=2)

    def forward(self, x):
        ...
        x = self.conv1(x)
        x = x[:, :, :-self.conv1.padding[0]]  # drop the padding[0] extra outputs at the end
        ...
        return x
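A short derivation of why slicing off `padding[0]` trailing outputs leaves a causal layer of the same length. It uses the output-length formula from the `nn.Conv1d` documentation with stride 1, writing k = kernel_size, d = dilation, p = padding, L = input length, and treating x_i as 0 for i < 0 (PyTorch's convolution is a cross-correlation, which is what the sum below computes):

```latex
L_{\text{out}} = L + 2p - d\,(k-1), \qquad p = d\,(k-1) \;\Longrightarrow\; L_{\text{out}} = L + p

y_t = \sum_{j=0}^{k-1} w_j \, x_{t + jd - p} = \sum_{j=0}^{k-1} w_j \, x_{t - (k-1-j)\,d}, \qquad 0 \le t < L
```

Because k - 1 - j >= 0, each kept output y_t reads only x_t and earlier samples; the p outputs with t >= L read the right-hand padding and are exactly what the slice discards. For kernel_size=3, dilation=2 (so p = 4) this reduces to y_t = w_0 x_{t-4} + w_1 x_{t-2} + w_2 x_t.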

If you really want, you could subclass Conv1d to trim the extra trailing outputs inside the forward() call.
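A minimal sketch of such a subclass, assuming an integer kernel_size and the default stride of 1. The class name `TrimmedCausalConv1d` is hypothetical, chosen so it does not clash with the `CausalConv1d` helper above. It also guards the kernel_size=1 case, where the padding is 0 and `out[:, :, :-0]` would return an empty tensor:

```python
import torch
import torch.nn as nn


class TrimmedCausalConv1d(nn.Conv1d):
    """Hypothetical subclass: symmetric Conv1d padding, trailing outputs trimmed in forward()."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        # Assumes an integer kernel_size and the default stride of 1.
        pad = (kernel_size - 1) * dilation
        super().__init__(in_channels, out_channels, kernel_size,
                         padding=pad, dilation=dilation, **kwargs)

    def forward(self, x):
        out = super().forward(x)
        pad = self.padding[0]  # Conv1d stores padding as a 1-tuple
        # Only slice when there is padding; out[:, :, :-0] would be empty.
        return out[:, :, :-pad] if pad > 0 else out


m = TrimmedCausalConv1d(256, 256, kernel_size=3, dilation=2)
x = torch.randn(8, 256, 100)
print(m(x).shape)  # torch.Size([8, 256, 100])
```

With this, `Network.forward()` can call `self.conv1(x)` directly and the output keeps the input's time length.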

You can check that the convolution is causal with:

>>> m = CausalConv1d(1, 1, kernel_size=3, dilation=2, bias=False)
>>> x = torch.randn(1, 1, 9)  # plain tensors work with autograd; Variable is no longer needed
>>> x[:, :, :3] = 0  # make the first three elements zero
>>> print(x)
>>> print(m(x))  # first three outputs should be zero (13 outputs; the last 4 come from the right padding)
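The zero-prefix check above tests a single input pattern. A slightly stricter sketch, assuming current PyTorch: it compares the trimmed output against explicit left-only padding (`torch.nn.functional.pad` followed by an unpadded `torch.nn.functional.conv1d` with the same weights), then perturbs only the last input sample and confirms that no earlier output changes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def CausalConv1d(in_channels, out_channels, kernel_size, dilation=1, **kwargs):
    # Same helper as in the post: symmetric padding of (kernel_size - 1) * dilation.
    pad = (kernel_size - 1) * dilation
    return nn.Conv1d(in_channels, out_channels, kernel_size,
                     padding=pad, dilation=dilation, **kwargs)


torch.manual_seed(0)
m = CausalConv1d(1, 1, kernel_size=3, dilation=2, bias=False)
pad = m.padding[0]
x = torch.randn(1, 1, 9)

with torch.no_grad():
    # Built-in padding, then drop the trailing `pad` outputs.
    y_trim = m(x)[:, :, :-pad]
    # Explicit left-only padding, then an unpadded convolution with the same weights.
    y_left = F.conv1d(F.pad(x, (pad, 0)), m.weight, dilation=m.dilation)

    # Perturb only the last input sample; no earlier output should change.
    x_future = x.clone()
    x_future[:, :, -1] += 10.0
    y_future = m(x_future)[:, :, :-pad]

print(torch.allclose(y_trim, y_left))                          # True
print(torch.allclose(y_trim[:, :, :-1], y_future[:, :, :-1]))  # True
```

The first comparison also shows that the trimmed built-in padding computes the same result as explicitly left-padding the input.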