ConvTranspose2d re-implementation via Fold

I’m exploring a toy re-implementation of ConvTranspose2d via torch.nn.Fold, in the same spirit as the Conv2d example implementation in the torch.nn.Unfold documentation.
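
For reference, the docs example implements Conv2d roughly like this (I’m paraphrasing it from memory, so treat it as a sketch):

import torch
import torch.nn.functional as F

inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)

# Conv2d = unfold -> per-patch matmul -> reshape
inp_unf = F.unfold(inp, (4, 5))  # B, (Cin * 4 * 5), L
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
out = out_unf.view(1, 2, 7, 8)   # matches F.conv2d(inp, w) up to float error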

I think this is pretty close: does anyone spot any problems (aside from parameter initialization)?

Compared to the native ConvTranspose2d, it produces what look like rather severe checkerboard artifacts, but I’m not sure whether that’s just untuned parameter initialization.

from torch import nn


class ConvTranspose2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        dilation=1,
        # groups=1,
        bias=True,
        # padding_mode="zeros",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # NOTE: kernel_size (and stride/padding/dilation) is assumed to be an
        # int, i.e. square kernels only.
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.dilation = dilation

        # The whole "transposed conv" is a per-pixel linear map: each input
        # pixel's Cin vector becomes a (Cout * K * K) patch for fold to place.
        self.linear = nn.Linear(
            in_features=self.in_channels,
            out_features=self.out_channels * self.kernel_size * self.kernel_size,
            bias=bias,
        )

    def forward(self, x):
        in_size = x.shape[2:]

        # See ConvTranspose2d documentation for this formula
        out_size = (
            # Height
            (in_size[0] - 1) * self.stride
            - 2 * self.padding
            + self.dilation * (self.kernel_size - 1)
            + self.output_padding
            + 1,
            # Width
            (in_size[1] - 1) * self.stride
            - 2 * self.padding
            + self.dilation * (self.kernel_size - 1)
            + self.output_padding
            + 1,
        )

        out_patches = self.linear(
            # Move channels to the last dimension (after flattening H x W)
            x.view(*x.shape[:-2], -1).transpose(-2, -1)
        ).view(
            # View as patches, with the number of sliding blocks last
            *x.shape[:-3], self.out_channels * self.kernel_size * self.kernel_size, -1
        )

        return nn.functional.fold(
            out_patches,
            output_size=out_size,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        )
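
For what it’s worth, here is the kind of check that makes the comparison concrete (a hypothetical harness with made-up sizes, not part of the module): copy a native layer’s weights into self.linear and compare outputs directly.

import torch
from torch import nn

torch.manual_seed(0)
cin, cout, k, s = 3, 8, 4, 2

ref = nn.ConvTranspose2d(cin, cout, k, stride=s, bias=False)
mine = ConvTranspose2d(cin, cout, k, stride=s, bias=False)

# Native weight is (Cin, Cout, K, K); the Linear expects (Cout * K * K, Cin).
with torch.no_grad():
    mine.linear.weight.copy_(ref.weight.view(cin, -1).t())

x = torch.randn(2, cin, 5, 5)
print(torch.allclose(ref(x), mine(x), atol=1e-5))  # False for the code above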

I think I’ve found the problems (described here for posterity).

Problem 1:

The code above relies on self.linear to add a bias to the output patches. However, the bias needs to be added after nn.functional.fold: fold sums overlapping patches, so a per-patch bias is accumulated a different number of times at different output pixels, which produces grid-patterned artifacts in the final output.
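
A quick way to see why (a minimal sketch, not part of the module): fold an all-ones patch tensor and look at the per-pixel overlap counts, which is exactly the pattern a per-patch bias gets scaled by.

import torch
import torch.nn.functional as F

k, s = 3, 2            # kernel_size, stride
h = w = 4              # input spatial size
out = (h - 1) * s + k  # output size for padding=0, dilation=1, output_padding=0

# fold sums overlapping patches, so folding ones yields the overlap count at
# every output pixel; with stride < kernel_size the counts form a grid pattern.
ones = torch.ones(1, k * k, h * w)
counts = F.fold(ones, output_size=(out, out), kernel_size=k, stride=s)
print(counts[0, 0])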

Problem 2:

The output patches were not reshaped correctly: the final .view needs to be a transpose, so that the patch dimension (Cout x K x K) comes before the sliding-block dimension (H x W):

out_patches = self.linear(
    x.view(*x.shape[:-2], -1).transpose(-2, -1)  # B, (H x W), Cin
).transpose(-2, -1)  # B, (Cout x K x K), (H x W)
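
Putting both fixes together, the module ends up roughly like this (same assumptions as above: int kernel_size, no groups, zero padding mode):

import torch
from torch import nn


class ConvTranspose2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, output_padding=0, dilation=1, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.dilation = dilation

        # Problem 1 fix: no bias inside the per-patch linear map ...
        self.linear = nn.Linear(
            in_features=in_channels,
            out_features=out_channels * kernel_size * kernel_size,
            bias=False,
        )
        # ... instead, one bias per output channel, added once per pixel after fold.
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    def forward(self, x):
        # See ConvTranspose2d documentation for this formula
        out_size = tuple(
            (n - 1) * self.stride
            - 2 * self.padding
            + self.dilation * (self.kernel_size - 1)
            + self.output_padding
            + 1
            for n in x.shape[2:]
        )

        # Problem 2 fix: transpose instead of view.
        out_patches = self.linear(
            x.view(*x.shape[:-2], -1).transpose(-2, -1)  # B, (H x W), Cin
        ).transpose(-2, -1)  # B, (Cout x K x K), (H x W)

        out = nn.functional.fold(
            out_patches,
            output_size=out_size,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        )
        if self.bias is not None:
            out = out + self.bias.view(1, -1, 1, 1)
        return out

With both fixes in place, the allclose check from earlier prints True (for bias=False), so any remaining checkerboard artifacts should be the usual ConvTranspose2d kind rather than an implementation bug.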