PyTorch equivalent of TensorFlow 2D convolution layer

Hi,

I am trying to implement a single 2D convolutional layer in both PyTorch and TensorFlow and get the same result, but the outputs from the two frameworks do not match.

Please note that I'm pretty new to the PyTorch framework.

Below is my code:

import tensorflow as tf
import torch
import torch.nn as nn
import numpy as np

# Set a random seed for reproducibility
np.random.seed(0)
tf.random.set_seed(0)
torch.manual_seed(0)

# Define input data
batch_size = 1
height = 28
width = 28
channels = 3

# Create a random 4-dimensional input tensor
input_data = np.random.rand(batch_size, height, width, channels)

# Convert input_data to PyTorch tensor with the correct channel order
input_tensor_torch = torch.tensor(input_data.transpose(0, 3, 1, 2), dtype=torch.float32)

# Convert input_data to TensorFlow tensor
input_tensor_tf = tf.convert_to_tensor(input_data, dtype=tf.float32)

# Define a simple convolutional network using PyTorch
class SimpleTorchNet(nn.Module):
    def __init__(self):
        super(SimpleTorchNet, self).__init__()
        self.conv = nn.Conv2d(in_channels=channels, out_channels=32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()  # Use ReLU activation

        # Initialize weights with Xavier/Glorot initialization
        nn.init.xavier_uniform_(self.conv.weight)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

# Define a simple convolutional network using TensorFlow
def simple_tf_net(input_shape):
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=input_shape, kernel_initializer='glorot_uniform'),
        tf.keras.layers.ReLU()  # Use ReLU activation
    ])
    return model

# Create instances of both networks
torch_model = SimpleTorchNet()
tf_model = simple_tf_net(input_tensor_tf.shape[1:])

# Set the PyTorch model to evaluation mode; compile() only configures the Keras model
torch_model.eval()
tf_model.compile()

# Forward pass through both models
torch_conv_output = torch_model.conv(input_tensor_torch)
tf_conv_output = tf_model.layers[0](input_tensor_tf)

# Convert the intermediate convolution outputs to NumPy arrays
torch_conv_output = torch_conv_output.detach().numpy()
tf_conv_output = tf_conv_output.numpy()

torch_conv_output = torch_conv_output.transpose(0,3,2,1)

# Check if the convolution outputs are almost equal
if np.allclose(torch_conv_output, tf_conv_output, rtol=1e-3, atol=1e-5):
    print("Convolution outputs are almost equal.")
else:
    print("Convolution outputs are not equal.")

# Print the convolution outputs for comparison
print("PyTorch Convolution Output:")
print(torch_conv_output)

print("TensorFlow Convolution Output:")
print(tf_conv_output)

And the outputs are:

Convolution outputs are not equal.
PyTorch Convolution Output:
[[[[ 2.48323381e-01 -1.44315615e-01 6.65397942e-02 … -5.15214056e-02
2.30577439e-02 -8.47472250e-03]
[ 2.21085072e-01 -1.24474078e-01 -2.88453009e-02 … -8.46047699e-03
7.81251043e-02 -1.45667195e-01]
[ 8.48063156e-02 -2.21538171e-02 1.40678465e-01 … -1.14889488e-01
2.38565207e-01 -1.21196359e-01]

[ 3.04039448e-01 -2.21818641e-01 -1.35927826e-01 … 8.10652971e-03
7.33793527e-02 -5.52261621e-02]
[ 2.63487756e-01 -2.04650789e-01 8.13340023e-03 … 1.38982907e-01
1.85968667e-01 -7.56865293e-02]
[ 2.19897881e-01 8.58511776e-03 -7.49853551e-02 … -6.54965043e-02
1.35003895e-01 9.49371085e-02]]

[[ 4.00723338e-01 -3.03725362e-01 1.09511361e-01 … 4.80477512e-03
1.05379745e-01 -1.27631158e-01]
[ 2.89355904e-01 -2.28242949e-01 -7.92173110e-03 … -8.79125446e-02
2.39318043e-01 2.62786567e-01]
[ 4.30388212e-01 -1.87168211e-01 1.65860169e-02 … 1.55370116e-01
1.54494122e-01 3.19483250e-01]

[ 3.05706799e-01 -2.78193891e-01 5.70965409e-02 … -1.21588096e-01
2.03510404e-01 1.66158199e-01]
[ 3.67006540e-01 -1.19644925e-01 2.02343352e-02 … 1.85276851e-01
1.19460210e-01 8.14109817e-02]
[ 8.13112631e-02 -2.03564644e-01 -1.71513155e-01 … 1.19606949e-01
1.67114452e-01 1.25381038e-01]]

[[ 4.85455275e-01 -3.55471492e-01 1.87599361e-02 … 8.85102376e-02
-3.87404859e-03 -9.97015685e-02]
[ 2.31786862e-01 3.58396322e-02 -2.74845194e-02 … 1.70969129e-01
1.58900067e-01 2.34841391e-01]
[ 5.09006202e-01 -2.22626969e-01 -5.72422929e-02 … -1.05161220e-02
2.05749273e-01 6.42738938e-02]

[ 2.07579628e-01 2.51404122e-02 -1.41085181e-02 … -2.67385095e-02
2.27404237e-01 1.31224945e-01]
[ 3.91307116e-01 -2.34702170e-01 1.43818542e-01 … -1.29898772e-01
2.08452195e-01 1.46718174e-01]
[ 3.14387083e-01 -2.03645393e-01 -2.29841303e-02 … 1.65475160e-02
9.40885246e-02 1.41478196e-01]]

[[ 2.82162994e-01 -1.72286659e-01 8.45219195e-02 … 1.04866259e-01
4.88728434e-02 1.28283054e-01]
[ 4.03141856e-01 -3.28208566e-01 -5.05447499e-02 … 7.14532733e-02
1.03780165e-01 2.24766165e-01]
[ 1.64524198e-01 9.13195908e-02 -2.92014241e-01 … 1.24898188e-01
1.76963031e-01 3.82717699e-02]

[ 4.79387224e-01 -1.97556108e-01 9.14509296e-02 … 2.72627696e-02
2.64826298e-01 5.59813417e-02]
[ 2.48289108e-02 -1.05925553e-01 -9.74631589e-03 … -1.01358593e-02
3.25453639e-01 -1.00798473e-01]
[ 2.42129162e-01 -3.08295548e-01 2.44555809e-02 … 6.37313500e-02
-3.64885181e-02 1.17797084e-01]]

[[ 3.55375528e-01 -2.31190592e-01 2.43358806e-01 … 1.93280280e-01
1.32792220e-02 -1.04077935e-01]
[ 3.29133868e-01 -9.87341180e-02 2.56389618e-01 … -6.40408695e-02
1.03478596e-01 1.90610662e-02]
[ 4.40511703e-01 -5.12388527e-01 1.63640246e-01 … 3.37819010e-02
1.00145623e-01 2.00788423e-01]

[ 4.28133070e-01 -3.16074491e-01 1.49143070e-01 … -7.19198883e-02
2.81987458e-01 3.78655493e-02]
[ 2.89786816e-01 -6.62525147e-02 4.44806851e-02 … -1.58710435e-01
6.36694282e-02 1.19430952e-01]
[-4.22601402e-03 -7.56801069e-02 -2.67677121e-02 … 1.10379256e-01
8.76079500e-02 8.95961374e-02]]

[[ 4.81417418e-01 -2.44408101e-03 -8.95669311e-03 … -9.69416648e-02
-9.36923921e-02 -1.47169352e-01]
[ 2.65685260e-01 -4.34794948e-02 1.99183762e-01 … -2.18919352e-01
1.03905439e-01 7.55844042e-02]
[ 3.45925063e-01 -5.63386977e-02 8.31451789e-02 … -1.26478925e-01
-6.54388964e-02 7.10651726e-02]

[ 1.93428233e-01 1.65429831e-01 -4.80100513e-04 … -5.75009733e-02
-1.05049536e-02 1.01330347e-01]
[ 2.75043190e-01 -1.71770126e-01 2.26597562e-02 … -1.24513611e-01
2.18288392e-01 2.01615721e-01]
[ 7.48421997e-02 -4.15402986e-02 -5.40232658e-03 … -1.73081234e-01
-3.22686210e-02 3.23817670e-01]]]]

TensorFlow Convolution Output:

[[[[-8.64531621e-02 -2.31194407e-01 -2.03320384e-01 … 7.32287839e-02
7.60681704e-02 2.43514553e-01]
[-2.17935741e-01 -1.39851928e-01 -9.45312679e-02 … 2.05472082e-01
2.53116190e-01 5.47657073e-01]
[ 1.95900537e-02 -1.91922367e-01 -4.58249450e-02 … 2.85228968e-01
1.39543518e-01 5.96701622e-01]

[-2.24213079e-01 -4.65130210e-02 -8.23171251e-03 … 4.51635212e-01
5.51500916e-02 2.83227235e-01]
[-1.61875978e-01 -3.05988610e-01 -1.43437237e-01 … 1.47720158e-01
5.93833216e-02 3.07234079e-01]
[ 1.02496721e-01 -3.63795578e-01 9.27821696e-02 … 2.78628170e-01
-4.65876833e-02 4.41644520e-01]]

[[-1.48203924e-01 -9.77261513e-02 -1.25923783e-01 … 2.76923686e-01
-2.17240714e-02 1.16366327e-01]
[-9.61510986e-02 1.63530067e-01 3.82284485e-02 … 4.85221654e-01
-9.89363864e-02 3.14035892e-01]
[-1.66874409e-01 6.41696528e-02 -7.04695061e-02 … 3.78075182e-01
-9.31830928e-02 1.46217629e-01]

[-2.74854690e-01 2.01384202e-02 -7.11751543e-03 … 3.27884436e-01
1.41117319e-01 3.53036404e-01]
[ 1.00046732e-01 -3.00218523e-01 -7.16970265e-02 … 1.61608502e-01
-1.28880426e-01 4.45543975e-01]
[ 1.49676889e-01 -2.70714551e-01 2.18152389e-01 … 2.42459446e-01
1.13611303e-01 3.71493429e-01]]

[[-8.16111788e-02 -5.05981892e-02 -1.35538921e-01 … 3.19916129e-01
-7.79381245e-02 2.22417414e-01]
[-2.00727344e-01 -9.19563323e-02 8.63034129e-02 … 4.18722898e-01
1.08559318e-01 2.54754782e-01]
[ 7.85998404e-02 -3.31695676e-01 1.11179680e-01 … 2.04992697e-01
-4.00718153e-02 2.42862508e-01]

[-2.57928222e-01 1.91073809e-02 2.74826139e-02 … 5.12124836e-01
-1.51765626e-03 1.42035022e-01]
[-1.06569074e-01 -1.24023505e-03 -2.12093666e-02 … 4.49546069e-01
2.47034766e-02 7.04151914e-02]
[ 1.53106347e-01 -1.93653554e-01 1.41600266e-01 … 3.02722186e-01
1.03283919e-01 2.65004575e-01]]

[[-2.02031717e-01 -1.59965798e-01 -2.33964667e-01 … 2.97624767e-01
3.25715244e-02 9.20782015e-02]
[-1.80526733e-01 4.57346030e-02 -1.35388553e-01 … 4.76938993e-01
-1.67653576e-01 5.25587082e-01]
[-2.11385009e-03 -3.81369352e-01 5.16906828e-02 … 4.11862999e-01
-5.21451570e-02 3.46830368e-01]

[-7.70914555e-02 -2.37294152e-01 1.58647392e-02 … 2.77898133e-01
-3.06099821e-02 5.01858354e-01]
[ 3.16166729e-02 -2.44096190e-01 1.75569430e-01 … 4.17368114e-01
-1.05751961e-01 5.40087759e-01]
[-5.11724909e-04 -2.55994588e-01 2.82634318e-01 … 4.00341719e-01
8.26932043e-02 3.20087224e-01]]

[[-2.49359027e-01 1.41489711e-02 -2.21551627e-01 … 1.81067407e-01
-3.34437564e-02 1.98972091e-01]
[-1.86305836e-01 5.34134693e-02 -1.27478316e-01 … 4.86616492e-01
1.40836567e-01 1.88983619e-01]
[ 1.40557543e-03 -1.13023140e-01 2.80143898e-02 … 4.44250762e-01
3.26635465e-02 2.81470537e-01]

[-8.15126896e-02 -2.91436650e-02 -6.73124641e-02 … 3.99910122e-01
-2.67125983e-02 1.67134881e-01]
[ 1.55402899e-01 -1.09187931e-01 1.35599911e-01 … 5.37904739e-01
1.76364295e-02 2.70472318e-01]
[ 2.24355862e-01 -1.93946451e-01 4.06818718e-01 … 1.40395373e-01
2.09924534e-01 1.82779536e-01]]

[[ 5.25706969e-02 1.63330406e-01 -7.44978935e-02 … 2.37995103e-01
-1.28021747e-01 2.02119738e-01]
[-1.23365037e-01 -1.67426784e-02 1.25069275e-01 … 1.54897049e-01
-4.81138229e-02 5.93677815e-03]
[ 3.19783315e-02 1.09110363e-02 5.54072931e-02 … 1.92608282e-01
3.80366296e-02 1.56384438e-01]

[ 7.22453147e-02 -6.69171661e-02 7.90397450e-02 … 4.93446104e-02
1.85365558e-01 7.20729828e-02]
[-3.74199525e-02 1.66067109e-01 1.49424940e-01 … 3.02835584e-01
-2.57160496e-02 1.29516935e-02]
[ 2.46175855e-01 4.64651883e-02 1.30395383e-01 … 3.21524382e-01
-1.14177186e-02 -1.26671225e-01]]]]

You are randomly initializing the layers in both frameworks without loading the parameters from one into the other. Random number generation is not identical across frameworks, so the mismatch is expected.
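
For example, here is a minimal sketch of one way to align them, assuming the models defined above: copy the TensorFlow kernel and bias into the PyTorch layer before comparing. Note that Keras stores conv kernels in HWIO layout while PyTorch uses OIHW:

kernel, bias = tf_model.layers[0].get_weights()  # kernel shape: (3, 3, 3, 32), i.e. HWIO
with torch.no_grad():
    # HWIO -> OIHW: (kh, kw, in, out) -> (out, in, kh, kw)
    torch_model.conv.weight.copy_(torch.from_numpy(kernel.transpose(3, 2, 0, 1)))
    torch_model.conv.bias.copy_(torch.from_numpy(bias))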

Hi,

I have updated my code, i.e., I now explicitly set the weight and bias parameters in both frameworks. But even then, the convolution outputs do not match.

import tensorflow as tf
import torch
import torch.nn as nn
import numpy as np

# Set a random seed for reproducibility
np.random.seed(0)
tf.random.set_seed(0)
torch.manual_seed(0)

# Define input data
batch_size = 1
height = 28
width = 28
channels = 3

# Create a random 4-dimensional input tensor
input_data = np.random.rand(batch_size, height, width, channels)

# Convert input_data to PyTorch tensor with the correct channel order
input_tensor_torch = torch.tensor(input_data.transpose(0, 3, 1, 2), dtype=torch.float32)

# Convert input_data to TensorFlow tensor
input_tensor_tf = tf.convert_to_tensor(input_data, dtype=tf.float32)

# Define a simple convolutional network using PyTorch
def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.ones_(m.weight.data)
        if m.bias is not None:
            nn.init.zeros_(m.bias.data)

class SimpleTorchNet(nn.Module):
    def __init__(self):
        super(SimpleTorchNet, self).__init__()
        self.conv = nn.Conv2d(in_channels=channels, out_channels=32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()  # Use ReLU activation

        # Initialize weights with Xavier/Glorot initialization
        # (overridden below by torch_model.apply(initialize_weights))
        nn.init.xavier_uniform_(self.conv.weight)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

# Define a simple convolutional network using TensorFlow
def simple_tf_net(input_shape):
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=input_shape, kernel_initializer=tf.keras.initializers.Ones(), bias_initializer='zeros'),
        tf.keras.layers.ReLU()  # Use ReLU activation
    ])
    return model

# Create instances of both networks
torch_model = SimpleTorchNet()
torch_model.apply(initialize_weights)  # set weights and bias parameters

tf_model = simple_tf_net(input_tensor_tf.shape[1:])

# Set the PyTorch model to evaluation mode; compile() only configures the Keras model
torch_model.eval()
tf_model.compile()

# Forward pass through both models
torch_conv_output = torch_model.conv(input_tensor_torch)
tf_conv_output = tf_model.layers[0](input_tensor_tf)

# Convert the intermediate convolution outputs to NumPy arrays
torch_conv_output = torch_conv_output.detach().numpy()
tf_conv_output = tf_conv_output.numpy()

torch_conv_output = torch_conv_output.transpose(0,3,2,1)

# Check if the convolution outputs are almost equal
if np.allclose(torch_conv_output, tf_conv_output, rtol=1e-3, atol=1e-5):
    print("Convolution outputs are almost equal.")
else:
    print("Convolution outputs are not equal.")

# Print the convolution outputs for comparison
print("PyTorch Convolution Output:")
print(torch_conv_output)

print("TensorFlow Convolution Output:")
print(tf_conv_output)

These are the outputs:

Convolution outputs are not equal.
PyTorch Convolution Output:
[[[[ 6.4356203 6.4356203 6.4356203 … 6.4356203 6.4356203
6.4356203]
[ 8.131134 8.131134 8.131134 … 8.131134 8.131134
8.131134 ]
[ 8.287735 8.287735 8.287735 … 8.287735 8.287735
8.287735 ]

[ 9.558674 9.558674 9.558674 … 9.558674 9.558674
9.558674 ]
[ 9.299343 9.299343 9.299343 … 9.299343 9.299343
9.299343 ]
[ 5.831039 5.831039 5.831039 … 5.831039 5.831039
5.831039 ]]

[[ 9.84642 9.84642 9.84642 … 9.84642 9.84642
9.84642 ]
[12.904908 12.904908 12.904908 … 12.904908 12.904908
12.904908 ]
[12.263399 12.263399 12.263399 … 12.263399 12.263399
12.263399 ]

[15.957125 15.957125 15.957125 … 15.957125 15.957125
15.957125 ]
[15.244741 15.244741 15.244741 … 15.244741 15.244741
15.244741 ]
[ 9.655096 9.655096 9.655096 … 9.655096 9.655096
9.655096 ]]

[[ 9.517401 9.517401 9.517401 … 9.517401 9.517401
9.517401 ]
[12.817567 12.817567 12.817567 … 12.817567 12.817567
12.817567 ]
[10.967307 10.967307 10.967307 … 10.967307 10.967307
10.967307 ]

[14.38706 14.38706 14.38706 … 14.38706 14.38706
14.38706 ]
[15.969399 15.969399 15.969399 … 15.969399 15.969399
15.969399 ]
[11.06389 11.06389 11.06389 … 11.06389 11.06389
11.06389 ]]

[[ 9.063613 9.063613 9.063613 … 9.063613 9.063613
9.063613 ]
[12.700915 12.700915 12.700915 … 12.700915 12.700915
12.700915 ]
[12.992161 12.992161 12.992161 … 12.992161 12.992161
12.992161 ]

[14.384545 14.384545 14.384545 … 14.384545 14.384545
14.384545 ]
[14.374221 14.374221 14.374221 … 14.374221 14.374221
14.374221 ]
[ 8.6435795 8.6435795 8.6435795 … 8.6435795 8.6435795
8.6435795]]

[[ 7.357518 7.357518 7.357518 … 7.357518 7.357518
7.357518 ]
[11.280155 11.280155 11.280155 … 11.280155 11.280155
11.280155 ]
[13.5729475 13.5729475 13.5729475 … 13.5729475 13.5729475
13.5729475]

[14.412306 14.412306 14.412306 … 14.412306 14.412306
14.412306 ]
[13.724648 13.724648 13.724648 … 13.724648 13.724648
13.724648 ]
[ 8.467352 8.467352 8.467352 … 8.467352 8.467352
8.467352 ]]

[[ 5.6951885 5.6951885 5.6951885 … 5.6951885 5.6951885
5.6951885]
[ 8.372838 8.372838 8.372838 … 8.372838 8.372838
8.372838 ]
[ 9.74875 9.74875 9.74875 … 9.74875 9.74875
9.74875 ]

[ 9.826389 9.826389 9.826389 … 9.826389 9.826389
9.826389 ]
[ 8.265122 8.265122 8.265122 … 8.265122 8.265122
8.265122 ]
[ 4.6325374 4.6325374 4.6325374 … 4.6325374 4.6325374
4.6325374]]]]

TensorFlow Convolution Output:

[[[[ 6.4356203 6.4356203 6.4356203 … 6.4356203 6.4356203
6.4356203]
[ 9.84642 9.84642 9.84642 … 9.84642 9.84642
9.84642 ]
[ 9.517401 9.517401 9.517401 … 9.517401 9.517401
9.517401 ]

[ 9.063613 9.063613 9.063613 … 9.063613 9.063613
9.063613 ]
[ 7.357518 7.357518 7.357518 … 7.357518 7.357518
7.357518 ]
[ 5.6951885 5.6951885 5.6951885 … 5.6951885 5.6951885
5.6951885]]

[[ 8.131134 8.131134 8.131134 … 8.131134 8.131134
8.131134 ]
[12.904908 12.904908 12.904908 … 12.904908 12.904908
12.904908 ]
[12.817567 12.817567 12.817567 … 12.817567 12.817567
12.817567 ]

[12.700915 12.700915 12.700915 … 12.700915 12.700915
12.700915 ]
[11.280155 11.280155 11.280155 … 11.280155 11.280155
11.280155 ]
[ 8.372838 8.372838 8.372838 … 8.372838 8.372838
8.372838 ]]

[[ 8.287735 8.287735 8.287735 … 8.287735 8.287735
8.287735 ]
[12.263399 12.263399 12.263399 … 12.263399 12.263399
12.263399 ]
[10.967307 10.967307 10.967307 … 10.967307 10.967307
10.967307 ]

[12.992161 12.992161 12.992161 … 12.992161 12.992161
12.992161 ]
[13.5729475 13.5729475 13.5729475 … 13.5729475 13.5729475
13.5729475]
[ 9.74875 9.74875 9.74875 … 9.74875 9.74875
9.74875 ]]

[[ 9.558674 9.558674 9.558674 … 9.558674 9.558674
9.558674 ]
[15.957125 15.957125 15.957125 … 15.957125 15.957125
15.957125 ]
[14.38706 14.38706 14.38706 … 14.38706 14.38706
14.38706 ]

[14.384545 14.384545 14.384545 … 14.384545 14.384545
14.384545 ]
[14.412306 14.412306 14.412306 … 14.412306 14.412306
14.412306 ]
[ 9.826389 9.826389 9.826389 … 9.826389 9.826389
9.826389 ]]

[[ 9.299343 9.299343 9.299343 … 9.299343 9.299343
9.299343 ]
[15.244741 15.244741 15.244741 … 15.244741 15.244741
15.244741 ]
[15.969399 15.969399 15.969399 … 15.969399 15.969399
15.969399 ]

[14.374221 14.374221 14.374221 … 14.374221 14.374221
14.374221 ]
[13.724648 13.724648 13.724648 … 13.724648 13.724648
13.724648 ]
[ 8.265122 8.265122 8.265122 … 8.265122 8.265122
8.265122 ]]

[[ 5.831039 5.831039 5.831039 … 5.831039 5.831039
5.831039 ]
[ 9.655096 9.655096 9.655096 … 9.655096 9.655096
9.655096 ]
[11.06389 11.06389 11.06389 … 11.06389 11.06389
11.06389 ]

[ 8.64358 8.64358 8.64358 … 8.64358 8.64358
8.64358 ]
[ 8.46735 8.46735 8.46735 … 8.46735 8.46735
8.46735 ]
[ 4.632537 4.632537 4.632537 … 4.632537 4.632537
4.632537 ]]]]

The output shows almost the same values, but in the wrong order, so I would revisit the transposes and make sure you are reordering the dimensions correctly. E.g., .transpose(0, 3, 2, 1) looks wrong, since it swaps the height and width axes.
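
As a quick sketch of the fix: PyTorch outputs are NCHW while TensorFlow outputs are NHWC, so the permutation that moves channels last without flipping height and width is (0, 2, 3, 1):

# NCHW -> NHWC: (N, C, H, W) -> (N, H, W, C)
torch_conv_output = torch_conv_output.transpose(0, 2, 3, 1)
np.allclose(torch_conv_output, tf_conv_output, rtol=1e-3, atol=1e-5)  # should now report a match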