How does ConvTranspose3D work?

Hi everyone,

Hope you are safe and well.

I am trying to get the inverse of a Conv3d operation by using the ConvTranspose3d.

My input has the size [128, 128, 12, 16, 12] and I want to get an output of size [128, 256, 24, 32, 24].

I successfully did it using torch.nn.ConvTranspose3d(128, 256, 3, 2, padding=1, output_padding=1), but I do not really understand how padding and output_padding actually work.

I read the second note in the documentation of ConvTranspose3d but still can’t manage to get a grip on the influence they have on each other.

What I have tried to get the right size is:

x = torch.randn(128, 128, 12, 16, 12)
print(x.shape)  # torch.Size([128, 128, 12, 16, 12])
c1 = torch.nn.ConvTranspose3d(128, 256, 3, 2)
print(c1(x).shape)  # torch.Size([128, 256, 25, 33, 25])
c1 = torch.nn.ConvTranspose3d(128, 256, 3, 2, padding=1)
print(c1(x).shape)  # torch.Size([128, 256, 23, 31, 23])
c1 = torch.nn.ConvTranspose3d(128, 256, 3, 2, padding=1, output_padding=1)
print(c1(x).shape)  # torch.Size([128, 256, 24, 32, 24])

Can you please explain to me how they really work?

Thank you and stay safe!

As the note says, padding corresponds to the padding argument in (forward) convolutions.
Now, when you have a nontrivial stride, several input sizes to the convolution typically give the same output size (with the larger ones having their last row(s)/column(s) ignored). The output_padding argument adjusts for this.

As a simple 2d example:

w = torch.randn(1, 1, 3, 3)
w2 = torch.randn_like(w)
x = torch.randn(1, 1, 4, 4)
# both the 4x4 input and its 3x3 crop give the same 1x1 output with stride 2
y1 = torch.nn.functional.conv2d(x, w, stride=2)
y2 = torch.nn.functional.conv2d(x[:, :, :-1, :-1], w, stride=2)
print(y1, "\n", y2)
# output_padding=1 recovers the larger (4x4) shape, output_padding=0 the 3x3 one
z1 = torch.nn.functional.conv_transpose2d(y1, w2, stride=2, output_padding=1)
z2 = torch.nn.functional.conv_transpose2d(y2, w2, stride=2, output_padding=0)
print(z1, "\n", z2)

As we see,

  • without output padding, we get the “just fits” input size,
  • with output padding we get a row and column of zeros for the values that the conv would have ignored.

Note that transposed convolutions are more an adjoint (in the linear map sense) than an inverse operation.
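
For instance, one can check the adjoint property numerically; a minimal sketch, reusing the stride-2 setup from above, where A is the map x ↦ conv2d(x, w, stride=2):

import torch

w = torch.randn(1, 1, 3, 3)
x = torch.randn(1, 1, 4, 4)
y = torch.randn(1, 1, 1, 1)  # same shape as conv2d(x, w, stride=2)
Ax = torch.nn.functional.conv2d(x, w, stride=2)
# output_padding=1 brings A^T y back to the 4x4 shape of x
ATy = torch.nn.functional.conv_transpose2d(y, w, stride=2, output_padding=1)
# adjoint property: <A x, y> == <x, A^T y> up to floating point error
print(torch.allclose((Ax * y).sum(), (x * ATy).sum()))  # True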

Best regards

Thomas

Thanks a lot for the swift response Thomas!

So, if I understood correctly what you said is that:

  1. The convolution layer is a many-to-one mapping due to the VALID (TensorFlow nomenclature) behavior combined with the given padding and stride. This means the function is not bijective, so it has no inverse.

  2. For deconvolution, we have a one-to-many mapping. While padding behaves the same as for convolution, output_padding selects among the many possible output shapes: it pads the output shape of the output_padding=0 counterpart before the computation starts and then makes sure the computation results in an output shape that is “padded” by the given output_padding.

Am I right?

Best regards and stay safe,

Dan

I’m not sure that (my mathematical idea of) many-to-one and one-to-many mappings is applicable here. To my mind, more than the VALID mode, the key bit here is the stride and the question “how many translated copies of the convolution stencil fit into the input data”, and whether there are columns/rows left over that are too few to fit another stencil copy. These then don’t show up in the convolution output and need to be adjusted for to get the transposed convolution to output the right shape.

Is the mathematical relationship between the input and the desired output the one outlined in Relationship 14 on page 26 of “A guide to convolution arithmetic for deep learning”?

Yes, and that’s the same formula as given in the ConvTranspose2d documentation in the “Shape:” section specialized to dilation = 1.
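
Spelled out with dilation = 1, the formula is o = (i − 1) * s − 2 * p + k + a; a small sketch (the helper name is just illustrative):

def convtranspose_out_size(in_size, stride, padding, kernel_size, output_padding):
    # o = (i - 1) * s - 2 * p + k + a, per spatial dimension
    return (in_size - 1) * stride - 2 * padding + kernel_size + output_padding

print(convtranspose_out_size(12, 2, 1, 3, 1))  # 24, matching the question above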

Best regards

Thomas

The solution that I found for the stride of 2 is:

torch.nn.ConvTranspose3d(
    in_channels=number_of_filters,
    out_channels=number_of_filters,
    kernel_size=kernel_size,
    stride=2,
    padding=kernel_size // 2,
    output_padding=kernel_size % 2,
)

When I have free time I will think about how to do it for strides different from 2, but for now I need to adapt it to non-isotropic kernels.

Thanks a lot Thomas for all the help and pointers!

Stay safe!

Correction to my solution:

torch.nn.ConvTranspose3d(
    in_channels=number_of_filters,
    out_channels=number_of_filters,
    kernel_size=kernel_size,
    stride=2,
    padding=kernel_size // 2 - (1 - kernel_size % 2),  # k // 2 for odd k, k // 2 - 1 for even k
    output_padding=kernel_size % 2,
)
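
A quick sanity check (hypothetical test code, not needed for the fix): since (i − 1) * 2 − 2 * p + k + a = 2 * i with these choices of p and a, the layer should exactly double every spatial dimension, for odd and even kernel sizes alike:

import torch

x = torch.randn(1, 4, 12, 16, 12)
for kernel_size in (2, 3, 4, 5):
    up = torch.nn.ConvTranspose3d(
        in_channels=4,
        out_channels=4,
        kernel_size=kernel_size,
        stride=2,
        padding=kernel_size // 2 - (1 - kernel_size % 2),
        output_padding=kernel_size % 2,
    )
    print(kernel_size, up(x).shape)  # spatial dims are 24, 32, 24 every time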

Hi Thomas,

Sorry to bother you again. I have tried to code Relationship 14 for the 2D case, but I can’t make it work due to a PyTorch restriction (output_padding < stride). Do you think you have time to have a look over it?

import torch


def calculate_deconv_params(
    input_shape, kernel_shape, padding_shape, stride_shape
):
    '''
    Calculates the padding and output_padding for a torch.nn.ConvTransposeXD to
    invert the effect of a torch.nn.ConvXD

    Args:
        input_shape: The shape of the tensor GIVEN TO the torch.nn.ConvXD
        kernel_shape: The kernel shape passed to the torch.nn.ConvXD's kernel_size
        padding_shape: The padding shape passed to the torch.nn.ConvXD's padding
        stride_shape: The stride shape passed to the torch.nn.ConvXD's stride

    Returns:
        A dictionary with the following format:
            {
                "kernel_size": k',
                "stride": s',
                "output_padding": a,
                "padding": p',
            }

    References:
        Dumoulin, V. and Visin, F., 2016.
        A guide to convolution arithmetic for deep learning.
        arXiv preprint arXiv:1603.07285.
    '''
    # k' = k
    deconv_kernel_shape = kernel_shape

    # This deviates from Relationship 14 as I assume that the torch.nn.ConvTransposeXD
    # uses the stride to actually insert zeroes in the input and then continues with a
    # stride of 1
    deconv_stride_shape = stride_shape

    # p' = k - p - 1
    deconv_padding_shape = [
        kernel_shape[i] - padding_shape[i] - 1 for i in range(len(kernel_shape))
    ]

    # a = (i + 2 * p - k) % s
    deconv_output_padding_shape = [
        (input_shape[i + 2] + 2 * padding_shape[i] - kernel_shape[i])
        % stride_shape[i]
        for i in range(len(kernel_shape))
    ]

    return {
        "kernel_size": deconv_kernel_shape,
        "stride": deconv_stride_shape,
        "output_padding": deconv_output_padding_shape,
        "padding": deconv_padding_shape,
    }


x = torch.rand(1, 1, 6, 6)

print("Working example")
kernel_size = (3, 3)
padding = (1, 1)
stride = (2, 2)

print(x.shape)

c1 = torch.nn.Conv2d(
    in_channels=1,
    out_channels=1,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
)
print(c1(x).shape)

c2_params = calculate_deconv_params(x.shape, kernel_size, padding, stride)
c2_params.update({"in_channels": 1, "out_channels": 1})
print(c2_params)

c2 = torch.nn.ConvTranspose2d(**c2_params)
print(c2(c1(x)).shape)


print("Breaking example")
kernel_size = (2, 2)
padding = (1, 1)
stride = (2, 2)

print(x.shape)

c1 = torch.nn.Conv2d(
    in_channels=1,
    out_channels=1,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
)
print(c1(x).shape)

c2_params = calculate_deconv_params(x.shape, kernel_size, padding, stride)
c2_params.update({"in_channels": 1, "out_channels": 1})
print(c2_params)

c2 = torch.nn.ConvTranspose2d(**c2_params)
print(c2(c1(x)).shape)

Thanks a lot and stay safe!

cc @ptrblck

I think you’re overthinking this: except for output_padding, you can just pass in the original convolution’s parameters (notably padding) to get the right result.
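
In code, a minimal sketch of that suggestion (transpose_params is a hypothetical helper name, not a PyTorch API): keep the forward conv’s kernel_size, stride, and padding, and only compute output_padding = (i + 2 * p − k) mod s per spatial dimension:

import torch

def transpose_params(input_shape, kernel_size, padding, stride):
    # keep the forward conv's kernel_size, stride, and padding; only the
    # output_padding a = (i + 2 * p - k) % s needs to be computed
    output_padding = tuple(
        (i + 2 * p - k) % s
        for i, k, p, s in zip(input_shape[2:], kernel_size, padding, stride)
    )
    return {
        "kernel_size": kernel_size,
        "stride": stride,
        "padding": padding,
        "output_padding": output_padding,
    }

x = torch.rand(1, 1, 6, 6)
for kernel_size, padding, stride in [
    ((3, 3), (1, 1), (2, 2)),  # the "working example" above
    ((2, 2), (1, 1), (2, 2)),  # the "breaking example" above
]:
    c1 = torch.nn.Conv2d(1, 1, kernel_size, stride, padding)
    c2 = torch.nn.ConvTranspose2d(
        1, 1, **transpose_params(x.shape, kernel_size, padding, stride)
    )
    print(c2(c1(x)).shape)  # torch.Size([1, 1, 6, 6]) in both cases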