Converting a Keras Model to PyTorch

Dear all,

I tried to convert my Keras model to PyTorch as shown below.

# Keras
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Lambda

def clip_only_center_and_nus(tensor):
    center_nus = tensor[:, 0:12]
    as_ = tensor[:, 12:]
    center_nus = tf.clip_by_value(center_nus, clip_value_min=0.0, clip_value_max=1.0)
    return tf.concat([center_nus, as_], axis=1)

def smoe_ae_tf(checkpoint_path: str) -> keras.Model:
    input_image = tf.keras.Input(shape=(8, 8, 1))

    ae = Conv2D(16, (3, 3), padding='same', activation='relu')(input_image)
    ae = Conv2D(32, (3, 3), padding='same', activation='relu')(ae)
    ae = Conv2D(64, (3, 3), padding='same', activation='relu')(ae)
    ae = Conv2D(128, (3, 3), padding='same', activation='relu')(ae)
    ae = Conv2D(256, (3, 3), padding='same', activation='relu')(ae)
    ae = Conv2D(512, (3, 3), padding='same', activation='relu')(ae)
    ae = Conv2D(1024, (3, 3), padding='same', activation='relu')(ae)
    ae = Flatten()(ae)
    ae = Dense(1024, activation='relu')(ae)
    ae = Dense(512, activation='relu')(ae)
    ae = Dense(256, activation='relu')(ae)
    ae = Dense(128, activation='relu')(ae)
    ae = Dense(64, activation='relu')(ae)
    ae = Dense(28, activation='linear')(ae)

    ae = Lambda(clip_only_center_and_nus)(ae)

    ae_model = keras.Model(inputs=input_image, outputs=ae)
    ae_model.load_weights(checkpoint_path)
    return ae_model
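For reference, here is a quick sketch (assuming the checkpoint at './path/cp.ckpt' below loads) of how I list the flat weight array that the transfer further down indexes into; Keras stores Conv2D kernels as (kH, kW, in_channels, out_channels) and Dense kernels as (in_features, out_features):

# Sketch: print every weight array in the order get_weights() returns them,
# i.e. [conv1 kernel, conv1 bias, conv2 kernel, conv2 bias, ...]
inspect_model = smoe_ae_tf('./path/cp.ckpt')   # hypothetical local name, same checkpoint as below
for i, w in enumerate(inspect_model.get_weights()):
    print(i, w.shape)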

import torch
import torch.nn as nn
import torch.nn.functional as F


def clamp_pb_pu(tensor):
    center_nus = tensor[:, 0:12]
    as_ = tensor[:, 12:]
    center_nus = torch.clamp(center_nus, min=0.0, max=1.0)
    return torch.cat([center_nus, as_], dim=1)


class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x): return self.func(x)


class smoe_ae(nn.Module):
    def __init__(self, in_features=16, out_features=32, kernel_num=3, stride=1, padding="same"):
        super(smoe_ae, self).__init__()
      
        self.conv1 = nn.Conv2d(1, in_features, kernel_size=kernel_num, stride=stride, padding=padding, dtype=torch.float32)
        self.conv2 = nn.Conv2d(in_features, out_features, kernel_size=kernel_num, stride=stride, padding=padding, dtype=torch.float32)
        self.conv3 = nn.Conv2d(in_features*2, out_features*2, kernel_size=kernel_num, stride=stride, padding=padding, dtype=torch.float32)
        self.conv4 = nn.Conv2d(in_features*4, out_features*4, kernel_size=kernel_num, stride=stride, padding=padding, dtype=torch.float32)
        self.conv5 = nn.Conv2d(in_features*8, out_features*8, kernel_size=kernel_num, stride=stride, padding=padding, dtype=torch.float32)
        self.conv6 = nn.Conv2d(in_features*16, out_features*16, kernel_size=kernel_num, stride=stride, padding=padding, dtype=torch.float32)
        self.conv7 = nn.Conv2d(in_features*32, out_features*32, kernel_size=kernel_num, stride=stride, padding=padding, dtype=torch.float32)
      
        # collapses the unbatched (C, H, W) activation into a single vector
        self.flatten = nn.Flatten(0, 2)

        self.fc1 = nn.Linear(in_features**4, out_features*32, dtype=torch.float32)   # 65536 -> 1024
        self.fc2 = nn.Linear(in_features*64, out_features*16, dtype=torch.float32)   # 1024 -> 512
        self.fc3 = nn.Linear(in_features*32, out_features*8, dtype=torch.float32)    # 512 -> 256
        self.fc4 = nn.Linear(in_features*16, out_features*4, dtype=torch.float32)    # 256 -> 128
        self.fc5 = nn.Linear(in_features*8, out_features*2, dtype=torch.float32)     # 128 -> 64
        self.fc6 = nn.Linear(in_features*4, out_features-4, dtype=torch.float32)     # 64 -> 28
         

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
      

        # flatten the (C, H, W) activation and add a batch dimension back
        x = self.flatten(x)[None, :]

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        x = self.fc6(x)

        # mirrors the Keras Lambda(clip_only_center_and_nus) layer
        x = Lambda(clamp_pb_pu)(x)

        return x
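The corresponding PyTorch parameter shapes can be listed the same way (a sketch for comparison; PyTorch stores Conv2d weights as (out_channels, in_channels, kH, kW) and Linear weights as (out_features, in_features)):

# Sketch: print the PyTorch parameters in definition order for a side-by-side
# comparison with the Keras weight list above
m = smoe_ae()   # hypothetical throwaway instance just for inspection
for name, p in m.named_parameters():
    print(name, tuple(p.shape))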

Then I passed the Keras model weights to the PyTorch model as follows:

checkpoint_path = './path/cp.ckpt'
model_keras = smoe_ae_tf(checkpoint_path)
model_torch = smoe_ae()

# flat list of NumPy arrays in layer order: [kernel, bias, kernel, bias, ...]
weights = model_keras.get_weights()

model_torch.conv1.weight.data = torch.from_numpy(np.transpose(weights[0]))
model_torch.conv1.bias.data = torch.from_numpy(weights[1])

model_torch.conv2.weight.data = torch.from_numpy(np.transpose(weights[2]))
model_torch.conv2.bias.data = torch.from_numpy(weights[3])

model_torch.conv3.weight.data = torch.from_numpy(np.transpose(weights[4]))
model_torch.conv3.bias.data = torch.from_numpy(weights[5])

model_torch.conv4.weight.data = torch.from_numpy(np.transpose(weights[6]))
model_torch.conv4.bias.data = torch.from_numpy(weights[7])

model_torch.conv5.weight.data = torch.from_numpy(np.transpose(weights[8]))
model_torch.conv5.bias.data = torch.from_numpy(weights[9])

model_torch.conv6.weight.data = torch.from_numpy(np.transpose(weights[10]))
model_torch.conv6.bias.data = torch.from_numpy(weights[11])

model_torch.conv7.weight.data = torch.from_numpy(np.transpose(weights[12]))
model_torch.conv7.bias.data = torch.from_numpy(weights[13])

model_torch.fc1.weight.data = torch.from_numpy(np.transpose(weights[14]))
model_torch.fc1.bias.data = torch.from_numpy(weights[15])

model_torch.fc2.weight.data = torch.from_numpy(np.transpose(weights[16]))
model_torch.fc2.bias.data = torch.from_numpy(weights[17])

model_torch.fc3.weight.data = torch.from_numpy(np.transpose(weights[18]))
model_torch.fc3.bias.data = torch.from_numpy(weights[19])

model_torch.fc4.weight.data = torch.from_numpy(np.transpose(weights[20]))
model_torch.fc4.bias.data = torch.from_numpy(weights[21])

model_torch.fc5.weight.data = torch.from_numpy(np.transpose(weights[22]))
model_torch.fc5.bias.data = torch.from_numpy(weights[23])

model_torch.fc6.weight.data = torch.from_numpy(np.transpose(weights[24]))
model_torch.fc6.bias.data = torch.from_numpy(weights[25])
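As a sanity check (a sketch; as far as I know, assigning to .data does not verify shapes, so a mismatch would otherwise go unnoticed), I also compared the transposed Keras arrays against the shapes the PyTorch layers expect. The shapes do line up, so it does not appear to be a plain shape mismatch:

# Sketch: `weights` is the get_weights() list from above; `torch_layers` is a
# hypothetical helper list pairing each Keras layer with its PyTorch counterpart
torch_layers = [model_torch.conv1, model_torch.conv2, model_torch.conv3,
                model_torch.conv4, model_torch.conv5, model_torch.conv6,
                model_torch.conv7, model_torch.fc1, model_torch.fc2,
                model_torch.fc3, model_torch.fc4, model_torch.fc5, model_torch.fc6]
for layer, kernel, bias in zip(torch_layers, weights[0::2], weights[1::2]):
    assert layer.weight.shape == np.transpose(kernel).shape   # kernel shapes match
    assert layer.bias.shape == bias.shape                     # bias shapes match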

However, the two models do not produce the same result for the same input:

batch_size, C, H, W = 1, 1, 8, 8
x = torch.randn(C, H, W, dtype=torch.float32)   # unbatched (C, H, W) for the PyTorch model

y = model_torch(x).detach().numpy()
y1 = model_keras(x.detach().numpy().reshape(batch_size, H, W, C)).numpy()   # NHWC for Keras


print(f'tensorflow autoencoder output {y1.shape} and PyTorch autoencoder output {y.shape}')

print('=======================================================')

print(f'tensorflow autoencoder output \n {y1} and PyTorch autoencoder output \n {y}')
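To quantify the difference instead of eyeballing the printouts, a small sketch of the comparison:

# Sketch: numeric comparison of the two (1, 28) output arrays
print('max abs difference:', np.abs(y1 - y).max())
print('allclose:', np.allclose(y1, y, atol=1e-5))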

Unfortunately, I could not figure out whether the problem lies in the PyTorch model definition or in the weight transfer.

Any help or advice would be appreciated.