Porting the weights of tf.nn.depthwise_conv2d to PyTorch's F.conv2d

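The main difference between TensorFlow and PyTorch here is the layout (shape) of the filter:
In TensorFlow, a depthwise filter has shape (kernel height, kernel width, input channels, channel multiplier),
while in PyTorch the conv2d weight has shape (output channels, input channels / groups, kernel height, kernel width).
We therefore need to permute the TensorFlow filter into the PyTorch layout. With a channel multiplier of 1, permute(2, 3, 0, 1) turns (kh, kw, C, 1) into (C, 1, kh, kw), which is exactly the weight expected by F.conv2d with groups=C.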
The original TensorFlow code is as follows:

 def image_derivs(x, nc):
     """Return the (dy, dx) Sobel derivatives of ``x`` using depthwise convolution.

     ``x`` is fed to tf.nn.depthwise_conv2d with its default data format, i.e. presumably
     NHWC with ``nc`` channels; 'VALID' padding means no border is added.
     """
     sobel_y = [[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]
     sobel_x = [[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]]

     def depthwise(kernel):
         # Filter of shape (kh, kw, nc, 1): the same 3x3 kernel for every input
         # channel, channel multiplier 1.
         filt = tf.tile(tf.expand_dims(tf.expand_dims(kernel, 2), 3), [1, 1, nc, 1])
         return tf.nn.depthwise_conv2d(x, filt, strides=[1, 1, 1, 1], padding='VALID')

     # dy is built before dx, matching the original graph construction order.
     dy = depthwise(sobel_y)
     dx = depthwise(sobel_x)
     return dy, dx

The corresponding PyTorch code:

class image_derivs(nn.Module):
    """Sobel image derivatives via depthwise ``F.conv2d``; PyTorch port of the TF ``image_derivs``.

    The TensorFlow version builds a depthwise filter of shape (kh, kw, nc, 1)
    (kernel height, kernel width, input channels, channel multiplier). PyTorch's
    ``F.conv2d`` with ``groups=nc`` expects a weight of shape
    (out_channels, in_channels / groups, kh, kw) = (nc, 1, kh, kw), so the TF filter
    is permuted with ``permute(2, 3, 0, 1)``. Both libraries compute cross-correlation,
    so the kernel does not need to be flipped.

    The kernels are fixed Sobel operators (constant tensors in the TensorFlow original),
    so they are stored as non-trainable parameters. They still appear in ``state_dict()``
    under ``weight_x`` / ``weight_y``, but they receive no gradients, so an optimizer
    leaves them unchanged.

    Args:
        nc: number of input channels; each channel is filtered independently.
    """

    # Vertical-derivative kernel (differences along the height axis).
    _KERNEL_Y = [[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]
    # Horizontal-derivative kernel (differences along the width axis).
    _KERNEL_X = [[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]]

    def __init__(self, nc):
        super(image_derivs, self).__init__()
        self.nc = nc
        # requires_grad=False: the TF filters are constants, so keep them frozen here too.
        self.weight_y = nn.Parameter(self._depthwise_weight(self._KERNEL_Y, nc), requires_grad=False)
        self.weight_x = nn.Parameter(self._depthwise_weight(self._KERNEL_X, nc), requires_grad=False)

    @staticmethod
    def _depthwise_weight(kernel, nc):
        """Convert a 2-D kernel into an ``F.conv2d`` depthwise weight of shape (nc, 1, kh, kw)."""
        # TF-layout filter (kh, kw, nc, 1), the equivalent of
        # tf.tile(tf.expand_dims(tf.expand_dims(kernel, 2), 3), [1, 1, nc, 1]).
        tf_filter = torch.tensor(kernel, dtype=torch.float32)[:, :, None, None].repeat(1, 1, nc, 1)
        # (kh, kw, in, multiplier=1) -> (in, 1, kh, kw), the PyTorch grouped-conv layout.
        return tf_filter.permute(2, 3, 0, 1).contiguous()

    def forward(self, x):
        """Return ``(d_y, d_x)`` for ``x``.

        ``x`` must have ``nc`` channels in NCHW layout (required by ``groups=self.nc``).
        With stride 1 and no padding (TF 'VALID'), each output is (N, nc, H - 2, W - 2).
        """
        d_x = F.conv2d(x, self.weight_x, stride=1, padding=0, groups=self.nc)
        d_y = F.conv2d(x, self.weight_y, stride=1, padding=0, groups=self.nc)
        return d_y, d_x

The test code:

def test_depthwise():
    """Check that the PyTorch ``image_derivs`` module matches ``tf.nn.depthwise_conv2d``.

    Runs both implementations on the same random integer-valued batch and asserts that
    both the vertical (``dy``) and horizontal (``dx``) derivatives agree. PyTorch works
    in NCHW layout and this TensorFlow code in NHWC, so the PyTorch tensors are permuted
    before comparison. Uses the TensorFlow 1.x graph/session API.
    """
    import numpy as np  # local import: only this test needs it

    nc = 64
    kernel_y = [[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]
    kernel_x = [[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]]

    model = image_derivs(nc)
    x = torch.randint(0, 10, (2, nc, 192, 192)).type(torch.float32)
    d_y, d_x = model(x)

    # NCHW -> NHWC so the PyTorch outputs line up with TensorFlow's layout. Plain numpy
    # arrays suffice; there is no need to round-trip them through a TF session.
    torch_dy = d_y.permute(0, 2, 3, 1).detach().numpy()
    torch_dx = d_x.permute(0, 2, 3, 1).detach().numpy()
    x_tf = tf.convert_to_tensor(x.permute(0, 2, 3, 1).numpy())

    def tf_filter(kernel):
        # (kh, kw) -> (kh, kw, nc, 1): one copy of the kernel per input channel,
        # channel multiplier 1, as in the TensorFlow image_derivs.
        return tf.tile(tf.expand_dims(tf.expand_dims(kernel, 2), 3), [1, 1, nc, 1])

    dy = tf.nn.depthwise_conv2d(x_tf, tf_filter(kernel_y), strides=[1, 1, 1, 1], padding='VALID')
    dx = tf.nn.depthwise_conv2d(x_tf, tf_filter(kernel_x), strides=[1, 1, 1, 1], padding='VALID')

    # The context manager guarantees the session is closed.
    with tf.Session() as sess:
        tf_dy, tf_dx = sess.run([dy, dx])

    # Assert rather than print, so a mismatch actually fails the test. Use a tolerance
    # rather than exact float equality.
    np.testing.assert_allclose(torch_dy, tf_dy, rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(torch_dx, tf_dx, rtol=1e-5, atol=1e-5)