Can’t you just set the padding in the Conv1d to ensure the convolution is causal? This is probably more efficient than explicitly padding the input:
import torch
import torch.nn as nn

def CausalConv1d(in_channels, out_channels, kernel_size, dilation=1, **kwargs):
    pad = (kernel_size - 1) * dilation  # pad both sides; the trailing part is removed later
    return nn.Conv1d(in_channels, out_channels, kernel_size, padding=pad, dilation=dilation, **kwargs)
...
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = CausalConv1d(256, 256, kernel_size=3, dilation=2)

    def forward(self, x):
        ...
        x = self.conv1(x)
        x = x[:, :, :-self.conv1.padding[0]]  # remove trailing padding
        ...
        return x
If you really want, you could subclass Conv1d to remove the padding in the forward() call.
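A rough, untested sketch of what that subclass could look like (it would replace the CausalConv1d helper above; the guard on padding[0] is just one way to handle kernel_size=1):

class CausalConv1d(nn.Conv1d):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size,
                         padding=(kernel_size - 1) * dilation,
                         dilation=dilation, **kwargs)

    def forward(self, x):
        out = super().forward(x)
        if self.padding[0] > 0:
            out = out[:, :, :-self.padding[0]]  # drop the outputs produced by the trailing padding
        return out

With this version, Network.forward() no longer needs the extra slicing line.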
You can check that the convolution is causal with:
>>> m = CausalConv1d(1, 1, kernel_size=3, dilation=2, bias=False)
>>> x = torch.randn(1, 1, 9)
>>> x[:, :, :3] = 0  # make the first three elements zero
>>> print(x)
>>> print(m(x))  # first three outputs should be zero
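(If m is the plain nn.Conv1d from the helper, m(x) will be longer than x because of the trailing padding, which is only sliced off in forward(); either way, the leading outputs only see the left padding and the zeroed inputs, so the first three entries should be exactly zero.)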