PyTorch conv_transpose2d and TF conv2d_transpose

I am converting TF weights to PyTorch.
I copied the weights for a ConvTranspose layer, but it produces different results.

  1. To simulate padding='SAME' in TF, PyTorch needs to zero-pad by one row and one column in some cases. Should that padding be applied before or after F.conv_transpose2d?

  2. Why does the snippet below produce the reversed result?

  3. Is any extra step needed for PyTorch to match the behavior of TF?

import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# TensorFlow (1.x graph mode) reference: transposed convolution of an all-ones
# NHWC input with an all-ones kernel, stride 1, 'SAME' padding.
dim = 4
inp = np.ones((1, dim, dim, 3))
# Named `kernel` rather than `filter` so the Python builtin is not shadowed.
kernel = np.ones((4, 4, 3, 3))

# Weight layouts expected by the two frameworks:
#   PyTorch conv_transpose2d: [in, out, kernel_size[0] (h), kernel_size[1] (w)] -> 3, 3, 4, 4
#   TF conv2d_transpose:      [filter_height, filter_width, out_channels, in_channels] -> 4, 4, 3, 3
# so a TF kernel is brought into PyTorch layout with kernel.transpose(3, 2, 0, 1).

x = tf.placeholder(shape=(1, dim, dim, 3), dtype=tf.float64)
# Layer-API alternative:
# x = tf.layers.conv2d_transpose(x, 3, kernel_size=4, strides=1, padding='SAME')
# 'SAME' is already the default of tf.nn.conv2d_transpose; it is spelled out
# because the padding mode is the whole point of the comparison.
y_op = tf.nn.conv2d_transpose(
    x, kernel, output_shape=(1, dim, dim, 3), strides=(1, 1, 1, 1), padding='SAME')

# Keep the graph tensor (y_op) separate from the fetched ndarray (y).
with tf.Session() as sess:
    y = sess.run(y_op, feed_dict={x: inp})

# NHWC -> NCHW so the array prints in the same layout as the PyTorch result.
y = np.transpose(y, (0, 3, 1, 2))
print(y.shape)
print('sum', y.sum())
print(y)


###################

# PyTorch side of the comparison: the same all-ones data, laid out as NCHW,
# with the weight as [in_channels, out_channels, kernel_h, kernel_w].
inp = torch.ones((1, 3, dim, dim), dtype=torch.float64)
weight = torch.ones((3, 3, 4, 4), dtype=torch.float64)

# Zero-pad one column on the right and one row at the bottom; F.pad lists the
# last dimension first: (left, right, top, bottom).
# NOTE: padding on this side is what makes the printed result come out as the
# TF output mirrored along both spatial axes. Padding the left/top instead,
# i.e. [1, 0, 1, 0], shifts the cropped window by one and lines up with TF.
inp = F.pad(inp, (0, 1, 0, 1), mode='constant', value=0)
# Output size per spatial dim: (5 - 1) * 1 - 2 * 2 + 4 = 4, i.e. back to dim.
y = F.conv_transpose2d(inp, weight, padding=2).numpy()
print(y.shape)
print('sum', y.sum())
print(y)

sum 1296.0
[[[[12. 18. 24. 18.]
   [18. 27. 36. 27.]
   [24. 36. 48. 36.]
   [18. 27. 36. 27.]]
  [[12. 18. 24. 18.]
   [18. 27. 36. 27.]
   [24. 36. 48. 36.]
   [18. 27. 36. 27.]]
  [[12. 18. 24. 18.]
   [18. 27. 36. 27.]
   [24. 36. 48. 36.]
   [18. 27. 36. 27.]]]]
(1, 3, 4, 4)
sum 1296.0
[[[[27. 36. 27. 18.]
   [36. 48. 36. 24.]
   [27. 36. 27. 18.]
   [18. 24. 18. 12.]]
  [[27. 36. 27. 18.]
   [36. 48. 36. 24.]
   [27. 36. 27. 18.]
   [18. 24. 18. 12.]]
  [[27. 36. 27. 18.]
   [36. 48. 36. 24.]
   [27. 36. 27. 18.]
   [18. 24. 18. 12.]]]]

Didn’t this one work? It pads the left and top edges instead of the right and bottom:

  inp = F.pad(inp, [1, 0, 1, 0], mode='constant', value=0)