How to convert keras model to Pytorch, and run inference in C++ correctly?

Hi,

Due to the current project’s speed requirement,
I’m trying to convert my keras model to use in Pytorch and run inference in C++.
(I’ve used Tensorflow in C++, but the speed cannot meet the goal.
I’ve seen some articles saying that pytorch perform well, so I’d like to give it a try)

To complete the process, the workflow I’ve done is like:

  1. Rewrite a model structure in Pytorch
  2. Load keras’s model weight and copy to the Pytorch one
  3. Save model to .pt
  4. Run inference in C++

Here’s the details I’ve done through the whole process:
*** 1.Rewrite a model structure in Pytorch

  • The original model structure with keras:
from keras.layers import Input, Dropout, Conv2DTranspose
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, concatenate


def small_unet(input_size = (256,256,1), label_CHANNELS = 2):
    inputs = Input(input_size)

    conv1 = Conv2D(8, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(8, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(16, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(16, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(32, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    conv5 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv5)
    up6 = concatenate([Conv2DTranspose(64, kernel_size=(2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
    conv6 = Conv2D(64, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Dropout(0.5)(conv6)
    conv6 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv6)
    up7 = concatenate([Conv2DTranspose(32, kernel_size=(2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
    conv7 = Conv2D(32, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Dropout(0.5)(conv7)
    conv7 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv7)
    up8 = concatenate([Conv2DTranspose(32, kernel_size=(2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
    conv8 = Conv2D(16, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Dropout(0.5)(conv8)
    conv8 = Conv2D(16, (3, 3), activation='relu', padding='same')(conv8)
    up9 = concatenate([Conv2DTranspose(16, kernel_size=(2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3) 
    conv9 = Conv2D(8, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Dropout(0.2)(conv9)
    conv9 = Conv2D(8, (3, 3), activation='relu', padding='same')(conv9)
    conv10 = Conv2D(label_CHANNELS, (1, 1), activation='softmax')(conv9)
    model = Model(inputs=inputs, outputs=conv10)

    return model


if __name__ == '__main__':
    model = small_unet(input_size = (1920,1920,1), label_CHANNELS = 4)
  • The rewrite model with Pytorch:
def add_conv_stage(dim_in, dim_out, kernel_size=(3,3), stride=1, padding=1):
    return nn.Sequential(
      nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, stride=stride, padding=padding),
      nn.ReLU(),
      nn.Conv2d(dim_out, dim_out, kernel_size=kernel_size, stride=stride, padding=padding),
      nn.ReLU()
    )

class Small_unet(nn.Module):
    def __init__(self, LABLE_CHANNELS=2):
        super(Small_unet, self).__init__()
        
        self.conv1 = add_conv_stage(1, 8)
        self.pool1 = nn.MaxPool2d((2, 2))
        self.conv2 = add_conv_stage(8, 16)
        self.pool2 = nn.MaxPool2d((2, 2))
        self.conv3 = add_conv_stage(16, 32)
        self.pool3 = nn.MaxPool2d((2, 2))
        self.conv4 = add_conv_stage(32, 64)
        self.pool4 = nn.MaxPool2d((2, 2))
        self.conv5 = add_conv_stage(64, 128)
        self.up6 = nn.ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
        self.conv6 = add_conv_stage(128, 64)
        self.up7 = nn.ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
        self.conv7 = add_conv_stage(64, 32)
        self.up8 = nn.ConvTranspose2d(32, 32, kernel_size=(2, 2), stride=(2, 2))
        self.conv8 = add_conv_stage(48, 16)
        self.up9 = nn.ConvTranspose2d(16, 16, kernel_size=(2, 2), stride=(2, 2))
        self.conv9 = add_conv_stage(24, 8)
        self.conv10 = nn.Conv2d(8, LABLE_CHANNELS, kernel_size=(1,1))
    
    def forward(self, x):
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        conv5 = self.conv5(pool4)
        up6 = torch.cat([self.up6(conv5), conv4], 1)
        conv6 = self.conv6(up6)
        up7 = torch.cat([self.up7(conv6), conv3], 1)
        conv7 = self.conv7(up7)
        up8 = torch.cat([self.up8(conv7), conv2], 1)
        conv8 = self.conv8(up8)
        up9 = torch.cat([self.up9(conv8), conv1], 1)
        conv9 = self.conv9(up9)
        conv10 = self.conv10(conv9)
        out = F.softmax(conv10, dim=1)
        return out

*** 2.Load keras’s model weight and copy to the Pytorch one

model = load_model('keras_model.h5')
weights=model.get_weights()

# copy weight from keras to pytorch
pt_model.conv1[0].weight.data = torch.from_numpy(np.transpose(weights[0]))
pt_model.conv1[2].weight.data = torch.from_numpy(np.transpose(weights[2]))

pt_model.conv2[0].weight.data = torch.from_numpy(np.transpose(weights[4]))
pt_model.conv2[2].weight.data = torch.from_numpy(np.transpose(weights[6]))

pt_model.conv3[0].weight.data = torch.from_numpy(np.transpose(weights[8]))
pt_model.conv3[2].weight.data = torch.from_numpy(np.transpose(weights[10]))

pt_model.conv4[0].weight.data = torch.from_numpy(np.transpose(weights[12]))
pt_model.conv4[2].weight.data = torch.from_numpy(np.transpose(weights[14]))

pt_model.conv5[0].weight.data = torch.from_numpy(np.transpose(weights[16]))
pt_model.conv5[2].weight.data = torch.from_numpy(np.transpose(weights[18]))

pt_model.up6.weight.data = torch.from_numpy(np.transpose(weights[20]))
pt_model.conv6[0].weight.data = torch.from_numpy(np.transpose(weights[22]))
pt_model.conv6[2].weight.data = torch.from_numpy(np.transpose(weights[24]))

pt_model.up7.weight.data = torch.from_numpy(np.transpose(weights[26]))
pt_model.conv7[0].weight.data = torch.from_numpy(np.transpose(weights[28]))
pt_model.conv7[2].weight.data = torch.from_numpy(np.transpose(weights[30]))

pt_model.up8.weight.data = torch.from_numpy(np.transpose(weights[32]))
pt_model.conv8[0].weight.data = torch.from_numpy(np.transpose(weights[34]))
pt_model.conv8[2].weight.data = torch.from_numpy(np.transpose(weights[36]))

pt_model.up9.weight.data = torch.from_numpy(np.transpose(weights[38]))
pt_model.conv9[0].weight.data = torch.from_numpy(np.transpose(weights[40]))
pt_model.conv9[2].weight.data = torch.from_numpy(np.transpose(weights[42]))

pt_model.conv10.weight.data = torch.from_numpy(np.transpose(weights[44]))

*** 3.Save model to .pt (torchscript)

  • The code I use to save to .pt file:
# save model
inp = torch.rand(1, 1, 1920, 1920)
traced_script_module = torch.jit.trace(pt_model, inp)
traced_script_module.save("trial_with_softmax.pt")

*** 4.Run inference in C++

  • So far, I can run the inference process with this code:
#define COMPILER_MSVC
#include <iostream>
#include <torch/script.h>
#include <memory>

#include "opencv2/opencv.hpp"
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp> 
#include <opencv2/core/core.hpp>       

#define kIMAGE_SIZE 1920
#define kCHANNELS 1
#define outCHANNELS 4

int main()
{
	torch::jit::script::Module module = torch::jit::load("trial_with_softmax.pt");
	module.to(at::kCUDA);

	//read input image
	cv::Mat inp_image = cv::imread("012.png", 0);
	inp_image.convertTo(inp_image, CV_32FC1, 1.0f / 255.0f);
	auto input_tensor = torch::from_blob(inp_image.data, { 1, kIMAGE_SIZE, kIMAGE_SIZE, kCHANNELS });
	input_tensor = input_tensor.permute({ 0, 3, 1, 2 });

	// to GPU
	input_tensor = input_tensor.to(at::kCUDA);

	torch::Tensor out_tensor = module.forward({ input_tensor }).toTensor();
	auto out_tens = out_tensor.to(torch::kFloat32);

	out_tensor = out_tensor.squeeze().detach().permute({ 1, 2, 0 });
	out_tensor = out_tensor.mul(255).clamp(0, 255).to(torch::kU8);
	out_tensor = out_tensor.to(torch::kCPU);

	std::clock_t post = std::clock();
	cv::Mat result(kIMAGE_SIZE, kIMAGE_SIZE, CV_8UC(outCHANNELS));
	cv::Mat ch0, ch1, ch2, ch3;
	std::memcpy((void *)result.data, out_tensor.data_ptr(), sizeof(torch::kU8) * out_tensor.numel());
	cv::extractChannel(result, ch0, 0);
	cv::extractChannel(result, ch1, 1);
	cv::extractChannel(result, ch2, 2);
	cv::extractChannel(result, ch3, 3);

	cv::imwrite("ch0.png", ch0);
	cv::imwrite("ch1.png", ch1);
	cv::imwrite("ch2.png", ch2);
	cv::imwrite("ch3.png", ch3);
}

However, the result is different from the one I got from my keras model…
This is my first time using Pytorch, I’m not very familiar with it.
I’m not sure is it the way of handling input/output data doing wrong,
or the way I convert the weight?
Maybe providing model will help, here is my keras model:
keras_model.h5

Thanks in advance for any help!

1 Like

Have you copied the biases as well, like pt_model.conv1[0].bias.data

1 Like

Hi,

Thanks for your reply!
Sorry I don’t know I have to copy biases too… :cry:
I’ve completed this part,
and here’s the code of copying biases in my case:

#copy bias from keras to pytorch
pt_model.conv1[0].bias.data = torch.from_numpy(np.transpose(weights[1]))
pt_model.conv1[2].bias.data = torch.from_numpy(np.transpose(weights[3]))

pt_model.conv2[0].bias.data = torch.from_numpy(np.transpose(weights[5]))
pt_model.conv2[2].bias.data = torch.from_numpy(np.transpose(weights[7]))

pt_model.conv3[0].bias.data = torch.from_numpy(np.transpose(weights[9]))
pt_model.conv3[2].bias.data = torch.from_numpy(np.transpose(weights[11]))

pt_model.conv4[0].bias.data = torch.from_numpy(np.transpose(weights[13]))
pt_model.conv4[2].bias.data = torch.from_numpy(np.transpose(weights[15]))

pt_model.conv5[0].bias.data = torch.from_numpy(np.transpose(weights[17]))
pt_model.conv5[2].bias.data = torch.from_numpy(np.transpose(weights[19]))

pt_model.up6.bias.data = torch.from_numpy(np.transpose(weights[21]))
pt_model.conv6[0].bias.data = torch.from_numpy(np.transpose(weights[23]))
pt_model.conv6[2].bias.data = torch.from_numpy(np.transpose(weights[25]))

pt_model.up7.bias.data = torch.from_numpy(np.transpose(weights[27]))
pt_model.conv7[0].bias.data = torch.from_numpy(np.transpose(weights[29]))
pt_model.conv7[2].bias.data = torch.from_numpy(np.transpose(weights[31]))

pt_model.up8.bias.data = torch.from_numpy(np.transpose(weights[33]))
pt_model.conv8[0].bias.data = torch.from_numpy(np.transpose(weights[35]))
pt_model.conv8[2].bias.data = torch.from_numpy(np.transpose(weights[37]))

pt_model.up9.bias.data = torch.from_numpy(np.transpose(weights[39]))
pt_model.conv9[0].bias.data = torch.from_numpy(np.transpose(weights[41]))
pt_model.conv9[2].bias.data = torch.from_numpy(np.transpose(weights[43]))

pt_model.conv10.bias.data = torch.from_numpy(weights[45])

The result seems more alike to the keras one
(Although there’s still a bit difference between them,
but it’s way better than the one I’ve done before)

By the way,
Do you have any idea whether there’s any redundant part in my C++ inference code?
I’m not sure if this is the most optimized code to run the inference process,
but I don’t know where can I revise for that… :frowning_face:

Hi there,

I don’t understand how the above code can work. Keras convolution weights are stored as

[Kx, Ky, C_in, C_out]

while Pytorch Conv2d weights are stored as

[C_out, C_in, Kx, Ky]

And a simple transposition does not take you from one to the other. See documentation below.
Could you please clarify if you really got the code above to work, or if you needed to implement a more complicated permutation between Keras and Pytorch tensors? Thanks!

https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html