I’m exploring a toy re-implementation of ConvTranspose2d via torch.nn.Fold, in the same spirit as the Conv2d example in the torch.nn.Unfold documentation. I think this is pretty close: does anyone spot any problems (aside from parameter initialization)?

It seems to produce results with what look like rather severe checkerboard artifacts compared to the native ConvTranspose2d, but I’m not sure whether that’s just down to untuned parameter initialization.
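For reference, the pattern I’m trying to mirror is the Conv2d-via-Unfold check from the torch.nn.Unfold docs, which goes roughly like this (reproduced from memory, so it may not match the docs snippet exactly):

import torch

inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)
# im2col: gather the 4x5 patches, multiply each by the flattened kernel, fold back
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
print((torch.nn.functional.conv2d(inp, w) - out).abs().max())

And here is my attempt at the transposed-convolution analogue: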
import torch
from torch import nn


class ConvTranspose2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        dilation=1,
        # groups=1,
        bias=True,
        # padding_mode="zeros",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.dilation = dilation
        # Per-pixel linear map (equivalent to a 1x1 convolution) producing one
        # kernel_size x kernel_size patch per output channel for every input pixel
        self.linear = nn.Linear(
            in_features=self.in_channels,
            out_features=self.out_channels * self.kernel_size * self.kernel_size,
            bias=bias,
        )

    def forward(self, x):
        in_size = x.shape[2:]
        # See ConvTranspose2d documentation for this formula
        out_size = (
            # Height
            (in_size[0] - 1) * self.stride
            - 2 * self.padding
            + self.dilation * (self.kernel_size - 1)
            + self.output_padding
            + 1,
            # Width
            (in_size[1] - 1) * self.stride
            - 2 * self.padding
            + self.dilation * (self.kernel_size - 1)
            + self.output_padding
            + 1,
        )
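        # Worked example of the formula above (my own arithmetic, as a sanity check):
        # with in_size[0] = 5, stride = 2, padding = 1, kernel_size = 3,
        # output_padding = 1, dilation = 1:
        #   (5 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 10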
        # (N, C_in, H, W) -> (N, H * W, C_in): channels last, one row per input pixel
        pixels = x.view(*x.shape[:-2], -1).transpose(-2, -1)
        # Per-pixel linear map: (N, H * W, C_in) -> (N, H * W, C_out * K * K)
        out_patches = self.linear(pixels)
        # fold expects (N, C_out * K * K, L), so transpose back rather than view
        # (a plain view would mix the patch and channel dimensions together)
        out_patches = out_patches.transpose(-2, -1)
        # Sum the overlapping patches into the output feature map
        return nn.functional.fold(
            out_patches,
            output_size=out_size,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        )
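In case it helps anyone reproduce the comparison, this is roughly the check I have in mind against the native module (a sketch run right after the class above, not something I’d call definitive). The permute(1, 2, 3, 0) weight mapping is my own reading of how Fold orders the C * kH * kW dimension, so treat it as an assumption, and bias is off because a Linear bias would get summed once per overlapping patch by fold, which isn’t the same as ConvTranspose2d’s per-channel bias:

torch.manual_seed(0)
mine = ConvTranspose2d(3, 8, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
ref = nn.ConvTranspose2d(3, 8, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)

with torch.no_grad():
    # Native weight layout is (in_channels, out_channels, kH, kW); the Linear wants
    # (out_channels * kH * kW, in_channels), with kernel positions in fold's order
    mine.linear.weight.copy_(ref.weight.permute(1, 2, 3, 0).reshape(8 * 3 * 3, 3))

x = torch.randn(2, 3, 5, 5)
print(mine(x).shape, ref(x).shape)      # expecting (2, 8, 10, 10) for both
print((mine(x) - ref(x)).abs().max())   # expecting something around float precision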