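The problem is that reshape() doesn't do what you think it's doing. It keeps the
elements in their existing memory order and just regroups them into the new shape.
It doesn't move the channel dimension, so it scrambles the two-dimensional structure
of the image. Use permute() on the PyTorch side and transpose() on the NumPy side to
actually reorder the dimensions.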
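To see the difference on something small, here's a NumPy-only illustration. The array
is made up just for the demo: its values encode their own positions, so you can see where
each element ends up. Note that with a single channel, reshape() and transpose() happen to
agree, because the memory order is identical. So in your code the input reshape is harmless;
the damage comes from reshaping the 64-channel output back.

```python
import numpy as np

# Tiny channels-last (NHWC) array: batch 1, height 2, width 3, 2 channels.
# Each value encodes its position: x[0, h, w, c] == (h * 3 + w) * 2 + c
x = np.arange(12).reshape(1, 2, 3, 2)

# transpose() moves the channel axis to position 1 (NCHW), so each
# channel is still a proper 2x3 image.
x_moved = x.transpose(0, 3, 1, 2)

# reshape() to the same target shape just refills the new dimensions in
# memory order, mixing values from both channels into "channel 0".
x_reshaped = x.reshape(1, 2, 2, 3)

print(x_moved[0, 0].tolist())     # [[0, 2, 4], [6, 8, 10]]  (really channel 0)
print(x_reshaped[0, 0].tolist())  # [[0, 1, 2], [3, 4, 5]]   (scrambled)

# With a single channel the memory order is the same, so both agree.
y = np.arange(6).reshape(1, 2, 3, 1)
print(np.array_equal(y.transpose(0, 3, 1, 2), y.reshape(1, 1, 2, 3)))  # True
```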
Here is a slightly tweaked version of your code with the fix for reshape()
at the end:
>>> import keras
>>> keras.__version__
'2.9.0'
>>>
>>> import torch
>>> torch.__version__
'1.13.0.dev20220727'
>>>
>>> import numpy as np
>>> np.__version__
'1.23.1'
>>>
>>> from keras.layers import Conv2D
>>> from torch import nn
>>>
>>> keras.utils.set_random_seed (2022)
>>> np.random.seed (2022)
>>>
>>> img = np.random.rand(1, 256, 256, 1)
>>>
>>> ## TF Init
>>> conv_tf = Conv2D(
... 64, 3, activation="relu", padding="same", kernel_initializer="he_normal",
... )
>>>
>>> _ = conv_tf(img)
>>> conv_tf.bias = np.random.rand(64)
>>>
>>> ## PT Init + copy weights
>>> conv_torch = nn.Conv2d(
... in_channels=1, out_channels=64, kernel_size=3, padding="same", bias=True
... )
>>> conv_torch.weight = nn.parameter.Parameter(
...     torch.Tensor(conv_tf.weights[0].numpy().transpose(3, 2, 0, 1))  # Keras kernel is (kh, kw, in, out); PyTorch weight is (out, in, kh, kw)
... )
>>> conv_torch.bias = nn.parameter.Parameter(torch.Tensor(conv_tf.bias))
>>>
>>> conv_torch = nn.Sequential(
... conv_torch,
... nn.ReLU()
... )
>>>
>>> pred_tf = conv_tf(img).numpy()
>>> pred_pt = conv_torch(torch.Tensor(img).reshape(1, 1, 256, 256)).detach().numpy().reshape(pred_tf.shape)
>>> pred_pt_fix = conv_torch(torch.Tensor(img).permute (0, 3, 1, 2)).detach().numpy().transpose (0, 2, 3, 1)
>>>
>>> np.abs (pred_tf - pred_pt).max() # reshape ruins the two-dimensional structure of the image
4.251324
>>> np.abs (pred_tf - pred_pt_fix).max() # using permute / transpose fixes the problem
9.536743e-07
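If you want to sanity-check the layout bookkeeping without TensorFlow or PyTorch installed,
here's a NumPy-only sketch. The helper names are just made up for this example, and the naive
convolution is a slow reference that assumes stride 1, "same" padding, and an odd kernel size.
It's not how either library actually computes Conv2D. It uses the same transposes as the code
above: (3, 2, 0, 1) for the kernel, and (0, 3, 1, 2) / (0, 2, 3, 1) for the activations.

```python
import numpy as np

def nhwc_to_nchw(x):
    # Channels last -> channels first, like torch's permute(0, 3, 1, 2).
    return x.transpose(0, 3, 1, 2)

def nchw_to_nhwc(x):
    # Channels first -> channels last, the transpose(0, 2, 3, 1) used above.
    return x.transpose(0, 2, 3, 1)

def keras_kernel_to_torch(k):
    # Keras kernel (kh, kw, in, out) -> PyTorch weight (out, in, kh, kw).
    return k.transpose(3, 2, 0, 1)

def conv_same_nhwc(x, k):
    # Naive stride-1, zero-padded "same" cross-correlation (odd kernels only),
    # channels-last input and Keras-style kernel. Hypothetical reference helper.
    n, h, w, _ = x.shape
    kh, kw, _, cout = k.shape
    xp = np.pad(x, ((0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2), (0, 0)))
    out = np.zeros((n, h, w, cout))
    for i in range(h):
        for j in range(w):
            patch = xp[:, i:i + kh, j:j + kw, :]  # (n, kh, kw, in)
            out[:, i, j, :] = np.tensordot(patch, k, axes=([1, 2, 3], [0, 1, 2]))
    return out

def conv_same_nchw(x, wt):
    # The same operation for channels-first input and PyTorch-style weight.
    n, _, h, w = x.shape
    cout, _, kh, kw = wt.shape
    xp = np.pad(x, ((0, 0), (0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2)))
    out = np.zeros((n, cout, h, w))
    for i in range(h):
        for j in range(w):
            patch = xp[:, :, i:i + kh, j:j + kw]  # (n, in, kh, kw)
            out[:, :, i, j] = np.tensordot(patch, wt, axes=([1, 2, 3], [1, 2, 3]))
    return out

rng = np.random.default_rng(0)
img = rng.random((1, 8, 8, 1))              # NHWC, one input channel
kernel = rng.standard_normal((3, 3, 1, 4))  # Keras layout, four output channels

ref = conv_same_nhwc(img, kernel)
out = conv_same_nchw(nhwc_to_nchw(img), keras_kernel_to_torch(kernel))

print(np.abs(ref - nchw_to_nhwc(out)).max())       # round-off only: layouts match
print(np.abs(ref - out.reshape(ref.shape)).max())  # large: reshape scrambles channels
```

Both naive convolutions compute the same sums; only the axis order of the arrays differs,
which is exactly what the transposes account for.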