Hi, I want to implement a causal 3D conv to process video sequences, with "replicate" padding in the spatial dimensions and zero padding in the temporal dimension. Here is my implementation. I use this causal conv in several residual blocks, and because of memory limits I wrap each residual module with torch.utils.checkpoint.
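To make the padding scheme concrete, here is a minimal sketch of what "replicate in space, zero in time" does to a tiny tensor. It assumes the standard `F.pad` convention for 5D input, where the pad tuple lists the last dimension first: `(W_left, W_right, H_top, H_bottom, T_front, T_back)`. The tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

x = torch.arange(8.0).reshape(1, 1, 2, 2, 2)  # [B, C, T, H, W]

# Replicate-pad H and W only; the temporal entries are (0, 0)
spatial = F.pad(x, (1, 1, 1, 1, 0, 0), mode="replicate")
# Causal zero padding: two zero frames before t = 0, none after the last frame
causal = F.pad(spatial, (0, 0, 0, 0, 2, 0))

print(causal.shape)                  # torch.Size([1, 1, 4, 4, 4])
print(causal[0, 0, :2].abs().sum())  # padded frames are all zeros
print(causal[0, 0, 2])               # first real frame with a replicated border
```

Splitting the two pads this way keeps the spatial border behaviour independent of the temporal one.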
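```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def format_tuple_size(name, value, n=3):
    # Hypothetical stand-in for the helper used in the post: int -> n-tuple
    return tuple(value) if isinstance(value, (tuple, list)) else (value,) * n


class TemporalCasualConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, spatial_padding_mode="replicate"):
        super().__init__()
        self.kernel_size = format_tuple_size("kernel_size", kernel_size)  # int to tuple
        self.stride = format_tuple_size("stride", stride)
        self.dilation = format_tuple_size("dilation", dilation)  # also works if dilation is a tuple
        # Total padding per dim (T, H, W)
        self.padding_dim = tuple((k - 1) * d for k, d in zip(self.kernel_size, self.dilation))
        # F.pad order for 5D input: (W_left, W_right, H_top, H_bottom, T_front, T_back)
        self.f_pad_param = (self.padding_dim[2] // 2 + self.padding_dim[2] % 2, self.padding_dim[2] // 2,
                            self.padding_dim[1] // 2 + self.padding_dim[1] % 2, self.padding_dim[1] // 2,
                            0, 0)
        self.spatial_padding_mode = spatial_padding_mode
        # Zero padding on both temporal ends; the trailing outputs are sliced off in forward()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=self.kernel_size, stride=self.stride,
                              padding=(self.padding_dim[0], 0, 0), dilation=self.dilation)

    def forward(self, x):
        """
        input: x[B, C_in, T, H, W]
        output: x[B, C_out, T', H', W']
        """
        _, _, t0, _, _ = x.shape
        x = self.conv(F.pad(x, self.f_pad_param, mode=self.spatial_padding_mode))
        last_slice = (t0 + self.padding_dim[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        if self.padding_dim[0] > 0:
            x = x[:, :, :last_slice, :, :]
        return x
```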
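One way to check the forward pass is to compare the symmetric-padding-plus-slicing trick against explicit left-only zero padding with identical weights, and then perturb the last input frame to confirm that earlier outputs do not change. The sketch below rebuilds both paths with plain `nn.Conv3d`; the channel counts, temporal stride of 2, input size and seed are arbitrary illustrative choices, not taken from the original setup.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
kt, kh, kw, d = 3, 3, 3, 1
stride = (2, 1, 1)
pt, ph, pw = (kt - 1) * d, (kh - 1) * d, (kw - 1) * d

def spatial_pad(t):
    # Replicate-pad H and W only; 5D pad order is (W_l, W_r, H_t, H_b, T_f, T_b)
    return F.pad(t, (pw // 2 + pw % 2, pw // 2, ph // 2 + ph % 2, ph // 2, 0, 0), mode="replicate")

x = torch.randn(2, 4, 9, 8, 8)  # [B, C, T, H, W]
t0 = x.shape[2]
last_slice = (t0 - 1) // stride[0] + 1  # same value as the post's formula when pt = d * (kt - 1)

# (a) Symmetric temporal zero padding inside Conv3d, then keep the first outputs
conv_sym = nn.Conv3d(4, 5, (kt, kh, kw), stride=stride, padding=(pt, 0, 0), dilation=d)
y_a = conv_sym(spatial_pad(x))[:, :, :last_slice]

# (b) Explicit causal padding: zeros only before the first frame, same weights
conv_left = nn.Conv3d(4, 5, (kt, kh, kw), stride=stride, padding=0, dilation=d)
conv_left.load_state_dict(conv_sym.state_dict())
y_b = conv_left(F.pad(spatial_pad(x), (0, 0, 0, 0, pt, 0)))
print(y_a.shape, torch.allclose(y_a, y_b, atol=1e-5))

# Causality probe: perturb only the last frame (t = 8)
x2 = x.clone()
x2[:, :, -1] += 1.0
y_a2 = conv_sym(spatial_pad(x2))[:, :, :last_slice]
# Output i sees frames up to i * stride[0], so only the last output (i = 4) may change
print(torch.allclose(y_a[:, :, :-1], y_a2[:, :, :-1]))
```

If both checks pass, the slicing keeps the convolution causal, and the extra trailing outputs that are computed and then discarded only cost some compute.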
I'm not sure whether my implementation is correct, especially in the backward pass and under the checkpoint mechanism. I've also noticed that this module hurts the metrics a lot (the loss is roughly 3x higher) compared with a vanilla 3D convolution with replicate padding.
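For the backward pass under checkpointing, a quick sanity test is to compute parameter gradients for a residual block with and without `torch.utils.checkpoint.checkpoint` and confirm they match. This sketch uses hypothetical `CausalConv3dSketch` and `ResBlock` classes (explicit left-only temporal padding) rather than the post's modules, and assumes a PyTorch version that accepts `use_reentrant=False`. The PyTorch documentation notes that the reentrant variant needs at least one input that requires grad; otherwise the checkpointed segment's parameters may get no gradients, which is relevant if the first block sees raw video.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class CausalConv3dSketch(nn.Module):
    """Hypothetical stand-in: replicate-pad H/W, zero-pad only the past in T."""
    def __init__(self, c_in, c_out, k=3):
        super().__init__()
        self.k = k
        self.conv = nn.Conv3d(c_in, c_out, k)

    def forward(self, x):
        p = self.k - 1
        x = F.pad(x, (p // 2, p - p // 2, p // 2, p - p // 2, 0, 0), mode="replicate")
        x = F.pad(x, (0, 0, 0, 0, p, 0))  # zeros before the first frame only
        return self.conv(x)


class ResBlock(nn.Module):
    """Hypothetical residual block, optionally checkpointed."""
    def __init__(self, c):
        super().__init__()
        self.body = nn.Sequential(CausalConv3dSketch(c, c), nn.ReLU(), CausalConv3dSketch(c, c))

    def forward(self, x, use_ckpt=False):
        if use_ckpt:
            # Non-reentrant checkpoint: recomputes self.body during backward
            return x + checkpoint(self.body, x, use_reentrant=False)
        return x + self.body(x)


torch.manual_seed(0)
block = ResBlock(4)
x = torch.randn(1, 4, 5, 6, 6)  # does not require grad, like raw video input

grads = []
for use_ckpt in (False, True):
    block.zero_grad()
    block(x, use_ckpt).square().mean().backward()
    grads.append([p.grad.clone() for p in block.parameters()])

same = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(*grads))
print(same)
```

Matching gradients indicate that checkpointing does not change what the block learns; any remaining gap would then need to be explained elsewhere.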
Can anyone offer some guidance?