I went ahead and decided to improve on @MicPie’s solution little bit. Specifically, it now accepts 4-tuple in form of ((pad_up, pad_bottom), (pad_left, pad_right))
for padding to allow for asymmetric padding. The transposed convolution function supports dilation as well as output padding. Also added additional functions to compute padding for desired output shape.
Below is full code. Note that currently PytTorch Conv2D layer accepts only 2-tuple so it only allows symmetric padding. To use asymmetric padding, currently you must additional padding layer either usin F.pad
or ZeroPadding layer. See here for more discussions.
import math
def num2tuple(num):
return num if isinstance(num, tuple) else (num, num)
def conv2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):
h_w, kernel_size, stride, pad, dilation = num2tuple(h_w), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation)
pad = num2tuple(pad[0]), num2tuple(pad[1])
h = math.floor((h_w[0] + sum(pad[0]) - dilation[0]*(kernel_size[0]-1) - 1) / stride[0] + 1)
w = math.floor((h_w[1] + sum(pad[1]) - dilation[1]*(kernel_size[1]-1) - 1) / stride[1] + 1)
return h, w
def convtransp2d_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1, out_pad=0):
h_w, kernel_size, stride, pad, dilation, out_pad = num2tuple(h_w), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(pad), num2tuple(dilation), num2tuple(out_pad)
pad = num2tuple(pad[0]), num2tuple(pad[1])
h = (h_w[0] - 1)*stride[0] - sum(pad[0]) + dialation[0]*(kernel_size[0]-1) + out_pad[0] + 1
w = (h_w[1] - 1)*stride[1] - sum(pad[1]) + dialation[1]*(kernel_size[1]-1) + out_pad[1] + 1
return h, w
def conv2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1):
h_w_in, h_w_out, kernel_size, stride, dilation = num2tuple(h_w_in), num2tuple(h_w_out), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation)
p_h = ((h_w_out[0] - 1)*stride[0] - h_w_in[0] + dilation[0]*(kernel_size[0]-1) + 1)
p_w = ((h_w_out[1] - 1)*stride[1] - h_w_in[1] + dilation[1]*(kernel_size[1]-1) + 1)
return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))
def convtransp2d_get_padding(h_w_in, h_w_out, kernel_size=1, stride=1, dilation=1, out_pad=0):
h_w_in, h_w_out, kernel_size, stride, dilation, out_pad = num2tuple(h_w_in), num2tuple(h_w_out), \
num2tuple(kernel_size), num2tuple(stride), num2tuple(dilation), num2tuple(out_pad)
p_h = -(h_w_out[0] - 1 - out_pad[0] - dialation[0]*(kernel_size[0]-1) - (h_w[0] - 1)*stride[0]) / 2
p_w = -(h_w_out[1] - 1 - out_pad[1] - dialation[1]*(kernel_size[1]-1) - (h_w[1] - 1)*stride[1]) / 2
return (math.floor(p_h/2), math.ceil(p_h/2)), (math.floor(p_w/2), math.ceil(p_w/2))