The output size of ConvTranspose2d differs from the expected output size

In the documentation, there is a code example with the following annotations:

import torch
from torch import nn, autograd

input = autograd.Variable(torch.randn(1, 16, 12, 12))
downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
h = downsample(input)
h.size()         # (1, 16, 6, 6)
output = upsample(h, output_size=input.size())
output.size()    # (1, 16, 12, 12)

However, the actual output.size() returns (1, 16, 11, 11). Is this normal? I think it should be (1, 16, 12, 12), because ConvTranspose2d is just the opposite operation of Conv2d.


This is expected, and that’s why there is the output_size argument to ConvTranspose2d, which compensates for it.
As an example, suppose there is a convolution with stride 2 and kernel size 2. Applying it to a 4x4 image produces a 2x2 image. Now, if you apply the same convolution to a 5x5 image, the result will also be 2x2, so for a given Conv2d, two input sizes map to the same output size.
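
A minimal sketch to verify this (the channel counts are arbitrary):

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=2, stride=2)
# both a 4x4 and a 5x5 input produce a 2x2 output
print(conv(torch.randn(1, 1, 4, 4)).shape)  # torch.Size([1, 1, 2, 2])
print(conv(torch.randn(1, 1, 5, 5)).shape)  # torch.Size([1, 1, 2, 2])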


ConvTranspose2d doesn’t have an output_size argument, do you mean output_padding?
Oh, I found output_size in the source code.


So how do I solve this problem?

There is an output_size argument in the nn.ConvTranspose2d.forward function. Just provide the desired output shape during the forward pass.
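
A quick sketch using the sizes from the first post (the layer arguments are taken from the documentation example):

upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
h = torch.randn(1, 16, 6, 6)
print(upsample(h).shape)                        # torch.Size([1, 16, 11, 11])
print(upsample(h, output_size=(12, 12)).shape)  # torch.Size([1, 16, 12, 12])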


I found that ‘output_padding is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding is only used to find output shape, but does not actually add zero-padding to output.’ in the official documentation for ConvTranspose2d.
So I think we can use the output_padding parameter to make the output sizes match.
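
For instance, setting output_padding=1 when constructing the upsampling layer from the first post should give the 12x12 output directly (a sketch, assuming the same sizes):

upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, output_padding=1)
h = torch.randn(1, 16, 6, 6)
print(upsample(h).shape)  # torch.Size([1, 16, 12, 12])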

@apaszke would you please answer my question above about the shape of the weight tensor in a transposed convolution?

It would be great to mention this in the docs of the nn.ConvTranspose2d object.

Hello! Does anyone know what to do when working with the functional version? F.conv_transpose2d does not have an output_size argument. I need to use the functional version since I am tying weights.

output_size is used to compute the right output_padding, as seen in this line of code.
However, as you can see, this method is not exposed publicly, so you could either define the output_padding manually or use this hacky way to invoke the method:

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.ConvTranspose2d(3, 1, 2, 2, bias=False)
x = torch.randn(1, 3, 10, 10)

# vanilla
output_vanilla = conv(x)
print(output_vanilla.shape)
> torch.Size([1, 1, 20, 20])

# output_size
output_size = conv(x, output_size=(21, 21))
print(output_size.shape)
> torch.Size([1, 1, 21, 21])

# functional API
weight = conv.weight.detach().clone()
output_func = F.conv_transpose2d(x, weight, stride=2)
print(output_func.shape)
> torch.Size([1, 1, 20, 20])
print((output_func-output_vanilla).abs().max())
> tensor(0., grad_fn=<MaxBackward1>)

# hacky way to get the output_padding (note: _output_padding is an
# internal method, so its signature may change between PyTorch versions)
output_padding = nn.ConvTranspose2d._output_padding(
    self=None,
    input=x,
    output_size=(21, 21),
    stride=(2, 2),
    padding=(0, 0),
    kernel_size=(2, 2)
)

output_func_size = F.conv_transpose2d(
    x, weight, stride=2, output_padding=output_padding)
print(output_func_size.shape)
> torch.Size([1, 1, 21, 21])
print((output_func_size-output_size).abs().max())
> tensor(0., grad_fn=<MaxBackward1>)

I’m not sure why _output_padding is wrapped in a class method and not exposed publicly, so I would rather recommend calculating the output_padding argument manually instead of relying on this hacky approach.
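
The manual calculation follows from the output shape formula in the ConvTranspose2d docs: out = (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1, which you can solve for output_padding. A small illustrative helper (the helper name is my own):

def manual_output_padding(in_size, out_size, stride, padding=0, kernel_size=2, dilation=1):
    # size the layer would produce with output_padding=0
    base = (in_size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + 1
    # result must satisfy 0 <= output_padding < stride (for dilation=1)
    return out_size - base

# for the example above: 10 -> 21 with stride=2, kernel_size=2
print(manual_output_padding(10, 21, stride=2))  # 1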


Thanks a lot! Yeah, I’ll start with tweaking output_padding, but that is a good alternative.

Hi, would I be able to put ConvTranspose2d in nn.Sequential, add other layers, and still pass output_size?
I am working on a WGAN, where I want to experiment with both nn.Upsample and nn.ConvTranspose2d.

Like this; I used both layers just to double the image dimensions:

def upconv(in_channels, out_channels, kernel_size=1, stride=1, padding=0, batch_norm=True,
           up=False, trans=True, scale_factor=2, mode='nearest', align_corners=None):
    layers = []
    if up:
        layers.append(nn.Upsample(scale_factor=scale_factor, mode=mode, align_corners=align_corners))
    if trans:
        layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding))
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_channels))
    return nn.Sequential(*layers)

Upsample works, but for ConvTranspose2d I need to pass output_size so that it just doubles the dimensions, and that is not possible through nn.Sequential:

a = torch.randn((1, 32, 16, 16))
upconv(32, 16)(a, output_size=(a.shape[0], a.shape[1], a.shape[2]*2, a.shape[3]*2))

How can I do this? :thinking:

I think the easiest way would be to create a custom module and pass the output_size to it, if it’s static, so that it can be used in an nn.Sequential container as seen here:

class MyConvTranspose2d(nn.Module):
    def __init__(self, conv, output_size):
        super(MyConvTranspose2d, self).__init__()
        self.output_size = output_size
        self.conv = conv
        
    def forward(self, x):
        x = self.conv(x, output_size=self.output_size)
        return x


conv = nn.ConvTranspose2d(1, 1, 2, 2)
x = torch.randn(1, 1, 24, 24)
out = conv(x)
print(out.shape)
> torch.Size([1, 1, 48, 48])

out = conv(x, output_size=(49, 49))
print(out.shape)
> torch.Size([1, 1, 49, 49])

my_conv = MyConvTranspose2d(conv, output_size=(49, 49))
out = my_conv(x)
print(out.shape)
> torch.Size([1, 1, 49, 49])

model = nn.Sequential(MyConvTranspose2d(conv, output_size=(49, 49)))
out = model(x)
print(out.shape)
> torch.Size([1, 1, 49, 49])

If you need to set output_size dynamically, I would recommend writing a custom model and not using nn.Sequential.


Thanks for the great idea. It is working. :smiley:

class UpOrTrans(nn.Module):
    def __init__(self,in_channels, out_channels, mode='up', kernel_size=1, stride=1, scale_factor=2,  padding=0, batch_norm=True, **kwargs):
        super(UpOrTrans, self).__init__()
        self.mode = mode
        upmode = kwargs.get('upmode', 'nearest')
        align_corners = kwargs.get('align_corners', None)
        self.batch_norm = batch_norm
        
        if self.mode == 'up':
            self.up = nn.Upsample(scale_factor=scale_factor, mode=upmode, align_corners=align_corners)
            self.conv = nn.Conv2d(in_channels, out_channels,  kernel_size, stride, padding)
        
        if self.mode == 'trans':
            self.trans = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        
        if self.batch_norm:
            self.bn = nn.BatchNorm2d(out_channels)
        
    def forward(self, x, output_size=None):
        if self.mode == 'up':
            x = self.up(x)
            x = self.conv(x)
        if self.mode == 'trans':
            x = self.trans(x, output_size=output_size)
        
        if self.batch_norm:
            x = self.bn(x) 
        return x
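
For completeness, a quick usage check (the arguments below are assumed, chosen so that both paths double the spatial size):

a = torch.randn(1, 32, 16, 16)

up = UpOrTrans(32, 16, mode='up')
print(up(a).shape)  # torch.Size([1, 16, 32, 32])

trans = UpOrTrans(32, 16, mode='trans', kernel_size=2, stride=2)
print(trans(a, output_size=(32, 32)).shape)  # torch.Size([1, 16, 32, 32])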