Hi,
I am trying to implement a single 2D Convolutional layer alone in both PyTorch and TF to get the same result. But the outputs across two frameworks are not matching.
Please note that I’m pretty new to Pytorch framework.
Below is my code:
import tensorflow as tf
import torch
import torch.nn as nn
import numpy as np
# Set a random seed for reproducibility
np.random.seed(0)
tf.random.set_seed(0)
torch.manual_seed(0)
# Define input data
batch_size = 1
height = 28
width = 28
channels = 3
# Create a random 4-dimensional input tensor
input_data = np.random.rand(batch_size, height, width, channels)
# Convert input_data to PyTorch tensor with the correct channel order
input_tensor_torch = torch.tensor(input_data.transpose(0, 3, 1, 2), dtype=torch.float32)
# Convert input_data to TensorFlow tensor
input_tensor_tf = tf.convert_to_tensor(input_data, dtype=tf.float32)
# Define a simple convolutional network using PyTorch
class SimpleTorchNet(nn.Module):
def __init__(self):
super(SimpleTorchNet, self).__init__()
self.conv = nn.Conv2d(in_channels=channels, out_channels=32, kernel_size=3, padding=1)
self.relu = nn.ReLU() # Use ReLU activation
# Initialize weights with Xavier/Glorot initialization
nn.init.xavier_uniform_(self.conv.weight)
def forward(self, x):
x = self.conv(x)
x = self.relu(x)
return x
# Define a simple convolutional network using TensorFlow
def simple_tf_net(input_shape):
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=input_shape, kernel_initializer='glorot_uniform'),
tf.keras.layers.ReLU() # Use ReLU activation
])
return model
# Create instances of both networks
torch_model = SimpleTorchNet()
tf_model = simple_tf_net(input_tensor_tf.shape[1:])
# Set both models to evaluation mode
torch_model.eval()
tf_model.compile()
# Forward pass through both models
torch_conv_output = torch_model.conv(input_tensor_torch)
tf_conv_output = tf_model.layers[0](input_tensor_tf)
# Convert the intermediate convolution outputs to NumPy arrays
torch_conv_output = torch_conv_output.detach().numpy()
tf_conv_output = tf_conv_output.numpy()
torch_conv_output = torch_conv_output.transpose(0,3,2,1)
# Check if the convolution outputs are almost equal
if np.allclose(torch_conv_output, tf_conv_output, rtol=1e-3, atol=1e-5):
print("Convolution outputs are almost equal.")
else:
print("Convolution outputs are not equal.")
# Print the convolution outputs for comparison
print("PyTorch Convolution Output:")
print(torch_conv_output)
print("TensorFlow Convolution Output:")
print(tf_conv_output)
And the outputs are:
Convolution outputs are not equal.
PyTorch Convolution Output:
[[[[ 2.48323381e-01 -1.44315615e-01 6.65397942e-02 … -5.15214056e-02
2.30577439e-02 -8.47472250e-03]
[ 2.21085072e-01 -1.24474078e-01 -2.88453009e-02 … -8.46047699e-03
7.81251043e-02 -1.45667195e-01]
[ 8.48063156e-02 -2.21538171e-02 1.40678465e-01 … -1.14889488e-01
2.38565207e-01 -1.21196359e-01]
…
[ 3.04039448e-01 -2.21818641e-01 -1.35927826e-01 … 8.10652971e-03
7.33793527e-02 -5.52261621e-02]
[ 2.63487756e-01 -2.04650789e-01 8.13340023e-03 … 1.38982907e-01
1.85968667e-01 -7.56865293e-02]
[ 2.19897881e-01 8.58511776e-03 -7.49853551e-02 … -6.54965043e-02
1.35003895e-01 9.49371085e-02]]
[[ 4.00723338e-01 -3.03725362e-01 1.09511361e-01 … 4.80477512e-03
1.05379745e-01 -1.27631158e-01]
[ 2.89355904e-01 -2.28242949e-01 -7.92173110e-03 … -8.79125446e-02
2.39318043e-01 2.62786567e-01]
[ 4.30388212e-01 -1.87168211e-01 1.65860169e-02 … 1.55370116e-01
1.54494122e-01 3.19483250e-01]
…
[ 3.05706799e-01 -2.78193891e-01 5.70965409e-02 … -1.21588096e-01
2.03510404e-01 1.66158199e-01]
[ 3.67006540e-01 -1.19644925e-01 2.02343352e-02 … 1.85276851e-01
1.19460210e-01 8.14109817e-02]
[ 8.13112631e-02 -2.03564644e-01 -1.71513155e-01 … 1.19606949e-01
1.67114452e-01 1.25381038e-01]]
[[ 4.85455275e-01 -3.55471492e-01 1.87599361e-02 … 8.85102376e-02
-3.87404859e-03 -9.97015685e-02]
[ 2.31786862e-01 3.58396322e-02 -2.74845194e-02 … 1.70969129e-01
1.58900067e-01 2.34841391e-01]
[ 5.09006202e-01 -2.22626969e-01 -5.72422929e-02 … -1.05161220e-02
2.05749273e-01 6.42738938e-02]
…
[ 2.07579628e-01 2.51404122e-02 -1.41085181e-02 … -2.67385095e-02
2.27404237e-01 1.31224945e-01]
[ 3.91307116e-01 -2.34702170e-01 1.43818542e-01 … -1.29898772e-01
2.08452195e-01 1.46718174e-01]
[ 3.14387083e-01 -2.03645393e-01 -2.29841303e-02 … 1.65475160e-02
9.40885246e-02 1.41478196e-01]]
…
[[ 2.82162994e-01 -1.72286659e-01 8.45219195e-02 … 1.04866259e-01
4.88728434e-02 1.28283054e-01]
[ 4.03141856e-01 -3.28208566e-01 -5.05447499e-02 … 7.14532733e-02
1.03780165e-01 2.24766165e-01]
[ 1.64524198e-01 9.13195908e-02 -2.92014241e-01 … 1.24898188e-01
1.76963031e-01 3.82717699e-02]
…
[ 4.79387224e-01 -1.97556108e-01 9.14509296e-02 … 2.72627696e-02
2.64826298e-01 5.59813417e-02]
[ 2.48289108e-02 -1.05925553e-01 -9.74631589e-03 … -1.01358593e-02
3.25453639e-01 -1.00798473e-01]
[ 2.42129162e-01 -3.08295548e-01 2.44555809e-02 … 6.37313500e-02
-3.64885181e-02 1.17797084e-01]]
[[ 3.55375528e-01 -2.31190592e-01 2.43358806e-01 … 1.93280280e-01
1.32792220e-02 -1.04077935e-01]
[ 3.29133868e-01 -9.87341180e-02 2.56389618e-01 … -6.40408695e-02
1.03478596e-01 1.90610662e-02]
[ 4.40511703e-01 -5.12388527e-01 1.63640246e-01 … 3.37819010e-02
1.00145623e-01 2.00788423e-01]
…
[ 4.28133070e-01 -3.16074491e-01 1.49143070e-01 … -7.19198883e-02
2.81987458e-01 3.78655493e-02]
[ 2.89786816e-01 -6.62525147e-02 4.44806851e-02 … -1.58710435e-01
6.36694282e-02 1.19430952e-01]
[-4.22601402e-03 -7.56801069e-02 -2.67677121e-02 … 1.10379256e-01
8.76079500e-02 8.95961374e-02]]
[[ 4.81417418e-01 -2.44408101e-03 -8.95669311e-03 … -9.69416648e-02
-9.36923921e-02 -1.47169352e-01]
[ 2.65685260e-01 -4.34794948e-02 1.99183762e-01 … -2.18919352e-01
1.03905439e-01 7.55844042e-02]
[ 3.45925063e-01 -5.63386977e-02 8.31451789e-02 … -1.26478925e-01
-6.54388964e-02 7.10651726e-02]
…
[ 1.93428233e-01 1.65429831e-01 -4.80100513e-04 … -5.75009733e-02
-1.05049536e-02 1.01330347e-01]
[ 2.75043190e-01 -1.71770126e-01 2.26597562e-02 … -1.24513611e-01
2.18288392e-01 2.01615721e-01]
[ 7.48421997e-02 -4.15402986e-02 -5.40232658e-03 … -1.73081234e-01
-3.22686210e-02 3.23817670e-01]]]]
TensorFlow Convolution Output:
[[[[-8.64531621e-02 -2.31194407e-01 -2.03320384e-01 … 7.32287839e-02
7.60681704e-02 2.43514553e-01]
[-2.17935741e-01 -1.39851928e-01 -9.45312679e-02 … 2.05472082e-01
2.53116190e-01 5.47657073e-01]
[ 1.95900537e-02 -1.91922367e-01 -4.58249450e-02 … 2.85228968e-01
1.39543518e-01 5.96701622e-01]
…
[-2.24213079e-01 -4.65130210e-02 -8.23171251e-03 … 4.51635212e-01
5.51500916e-02 2.83227235e-01]
[-1.61875978e-01 -3.05988610e-01 -1.43437237e-01 … 1.47720158e-01
5.93833216e-02 3.07234079e-01]
[ 1.02496721e-01 -3.63795578e-01 9.27821696e-02 … 2.78628170e-01
-4.65876833e-02 4.41644520e-01]]
[[-1.48203924e-01 -9.77261513e-02 -1.25923783e-01 … 2.76923686e-01
-2.17240714e-02 1.16366327e-01]
[-9.61510986e-02 1.63530067e-01 3.82284485e-02 … 4.85221654e-01
-9.89363864e-02 3.14035892e-01]
[-1.66874409e-01 6.41696528e-02 -7.04695061e-02 … 3.78075182e-01
-9.31830928e-02 1.46217629e-01]
…
[-2.74854690e-01 2.01384202e-02 -7.11751543e-03 … 3.27884436e-01
1.41117319e-01 3.53036404e-01]
[ 1.00046732e-01 -3.00218523e-01 -7.16970265e-02 … 1.61608502e-01
-1.28880426e-01 4.45543975e-01]
[ 1.49676889e-01 -2.70714551e-01 2.18152389e-01 … 2.42459446e-01
1.13611303e-01 3.71493429e-01]]
[[-8.16111788e-02 -5.05981892e-02 -1.35538921e-01 … 3.19916129e-01
-7.79381245e-02 2.22417414e-01]
[-2.00727344e-01 -9.19563323e-02 8.63034129e-02 … 4.18722898e-01
1.08559318e-01 2.54754782e-01]
[ 7.85998404e-02 -3.31695676e-01 1.11179680e-01 … 2.04992697e-01
-4.00718153e-02 2.42862508e-01]
…
[-2.57928222e-01 1.91073809e-02 2.74826139e-02 … 5.12124836e-01
-1.51765626e-03 1.42035022e-01]
[-1.06569074e-01 -1.24023505e-03 -2.12093666e-02 … 4.49546069e-01
2.47034766e-02 7.04151914e-02]
[ 1.53106347e-01 -1.93653554e-01 1.41600266e-01 … 3.02722186e-01
1.03283919e-01 2.65004575e-01]]
…
[[-2.02031717e-01 -1.59965798e-01 -2.33964667e-01 … 2.97624767e-01
3.25715244e-02 9.20782015e-02]
[-1.80526733e-01 4.57346030e-02 -1.35388553e-01 … 4.76938993e-01
-1.67653576e-01 5.25587082e-01]
[-2.11385009e-03 -3.81369352e-01 5.16906828e-02 … 4.11862999e-01
-5.21451570e-02 3.46830368e-01]
…
[-7.70914555e-02 -2.37294152e-01 1.58647392e-02 … 2.77898133e-01
-3.06099821e-02 5.01858354e-01]
[ 3.16166729e-02 -2.44096190e-01 1.75569430e-01 … 4.17368114e-01
-1.05751961e-01 5.40087759e-01]
[-5.11724909e-04 -2.55994588e-01 2.82634318e-01 … 4.00341719e-01
8.26932043e-02 3.20087224e-01]]
[[-2.49359027e-01 1.41489711e-02 -2.21551627e-01 … 1.81067407e-01
-3.34437564e-02 1.98972091e-01]
[-1.86305836e-01 5.34134693e-02 -1.27478316e-01 … 4.86616492e-01
1.40836567e-01 1.88983619e-01]
[ 1.40557543e-03 -1.13023140e-01 2.80143898e-02 … 4.44250762e-01
3.26635465e-02 2.81470537e-01]
…
[-8.15126896e-02 -2.91436650e-02 -6.73124641e-02 … 3.99910122e-01
-2.67125983e-02 1.67134881e-01]
[ 1.55402899e-01 -1.09187931e-01 1.35599911e-01 … 5.37904739e-01
1.76364295e-02 2.70472318e-01]
[ 2.24355862e-01 -1.93946451e-01 4.06818718e-01 … 1.40395373e-01
2.09924534e-01 1.82779536e-01]]
[[ 5.25706969e-02 1.63330406e-01 -7.44978935e-02 … 2.37995103e-01
-1.28021747e-01 2.02119738e-01]
[-1.23365037e-01 -1.67426784e-02 1.25069275e-01 … 1.54897049e-01
-4.81138229e-02 5.93677815e-03]
[ 3.19783315e-02 1.09110363e-02 5.54072931e-02 … 1.92608282e-01
3.80366296e-02 1.56384438e-01]
…
[ 7.22453147e-02 -6.69171661e-02 7.90397450e-02 … 4.93446104e-02
1.85365558e-01 7.20729828e-02]
[-3.74199525e-02 1.66067109e-01 1.49424940e-01 … 3.02835584e-01
-2.57160496e-02 1.29516935e-02]
[ 2.46175855e-01 4.64651883e-02 1.30395383e-01 … 3.21524382e-01
-1.14177186e-02 -1.26671225e-01]]]]