Difference in the output image when using a traced model (.pt) with C++ OpenCV

Hello,
I have retrained a model based on EnlightenGAN and traced it so that I can run it in a C++ application using libtorch v1.6. However, I am getting slightly different results compared to the Python version (which executes the same traced model).

The model takes two inputs: the RGB image tensor and an attention-map tensor. The attention map tells the model which image regions require contrast enhancement.
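From the preprocessing code below, the attention map is one minus the luma of the normalized input rescaled to [0, 1], so darker regions receive higher attention:

attention = 1 - (0.299*(r + 1) + 0.587*(g + 1) + 0.114*(b + 1)) / 2

where r, g and b are the channels after the Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform, i.e. values in [-1, 1].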

Below is the Python code that runs inference with the traced (.pt) model:


import os

import cv2
import torch
import torchvision.transforms as transforms

def getTransform():
    # ToTensor scales to [0, 1]; Normalize then maps to [-1, 1].
    transform_list = [transforms.ToTensor(),
                      transforms.Normalize((0.5, 0.5, 0.5),
                                           (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def convertToCV(tensor):
    # (1, C, H, W) -> (H, W, C), denormalized from [-1, 1] back to [0, 255].
    tensor = torch.squeeze(tensor)
    tensor = tensor.cpu().float().detach()
    tensor = tensor.permute(1, 2, 0)
    tensor = ((tensor + 1) / 2.0) * 255.0
    return tensor.numpy()

def preprocess(image):
    transform = getTransform()
    trgbImage = transform(image)
    # Attention map: 1 - luma of the [0, 2]-shifted channels, scaled to [0, 1].
    r, g, b = trgbImage[0] + 1, trgbImage[1] + 1, trgbImage[2] + 1
    tattentionImage = 1. - ((0.299 * r + 0.587 * g + 0.114 * b) / 2.)
    tattentionImage = torch.unsqueeze(tattentionImage, 0)
    # Add the batch dimension expected by the traced model.
    trgbImage = torch.unsqueeze(trgbImage, 0)
    tattentionImage = torch.unsqueeze(tattentionImage, 0)
    return trgbImage, tattentionImage

def run(inputPath, outputPath):
    modelToLoad = torch.jit.load("./EGAN.pt")
    for filename in os.listdir(inputPath):
        print("Processing Image : ", filename)
        inputImage = cv2.imread(os.path.join(inputPath, filename))
        rgbImage, attentionImage = preprocess(inputImage)
        fake, real = modelToLoad.forward(rgbImage, attentionImage)
        fake_B = convertToCV(fake)
        fake_B = cv2.cvtColor(fake_B, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(outputPath, filename), fake_B)

The C++ version of the inference code is below:

// BT.601 luma coefficients
constexpr double A = 0.299;
constexpr double B = 0.587;
constexpr double C = 0.114;

cv::Mat torchTensortoCVMat(torch::Tensor& tensor)
{
	// (1, C, H, W) -> (H, W, C), denormalized from [-1, 1] back to [0, 255].
	tensor = tensor.squeeze(0);
	tensor = tensor.to(torch::kCPU).to(torch::kFloat32).detach();
	tensor = tensor.permute({ 1, 2, 0 }).contiguous();
	tensor = tensor.add(1).div(2.0).mul(255.0);
	tensor = tensor.to(torch::kU8);

	int64_t height = tensor.size(0);
	int64_t width  = tensor.size(1);
	// Note: the Mat wraps the tensor's memory, so the tensor must outlive it.
	cv::Mat mat(cv::Size(width, height), CV_8UC3, tensor.data_ptr<uchar>());
	return mat;
}

std::vector<torch::jit::IValue> CV2Tensor(const cv::Mat& cv_Image)
{
	// (H, W, C) uint8 -> (C, H, W) float in [-1, 1].
	torch::Tensor tInputImage = torch::from_blob(cv_Image.data, { cv_Image.rows, cv_Image.cols, cv_Image.channels() }, torch::kByte);
	tInputImage = tInputImage.to(torch::kFloat).div(255);
	tInputImage = tInputImage.sub(0.5).div(0.5).permute({ 2, 0, 1 });

	// Attention map: 1 - luma of the [0, 2]-shifted channels, scaled to [0, 1].
	torch::Tensor red   = (tInputImage[0] + 1).mul(A);
	torch::Tensor green = (tInputImage[1] + 1).mul(B);
	torch::Tensor blue  = (tInputImage[2] + 1).mul(C);

	torch::Tensor channelSum = red.add(green).add(blue).div(2.);
	torch::Tensor tGrayImage = 1. - channelSum;

	// Add the batch (and, for the attention map, channel) dimensions.
	tGrayImage.unsqueeze_(0);
	tGrayImage.unsqueeze_(0);
	tInputImage.unsqueeze_(0);

	std::vector<torch::jit::IValue> input;
	input.push_back(tInputImage);
	input.push_back(tGrayImage);
	return input;
}

void enhanceImage(const std::string& Img, torch::jit::script::Module& network, const std::string& outputPath, std::string& fileName)
{
	cv::Mat inputImage = cv::imread(Img);
	std::vector<torch::jit::IValue> input = CV2Tensor(inputImage);

	try
	{
		auto outputs = network.forward(input).toTuple();
		torch::Tensor resultFake = outputs->elements()[0].toTensor();

		cv::Mat output1 = torchTensortoCVMat(resultFake);
		cv::imshow("out1.png", output1);
		cv::waitKey(0);
	}
	catch (std::exception& e)
	{
		std::cout << e.what() << std::endl;
	}
}
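
For completeness, the module is loaded and the function invoked along the following lines (the paths here are placeholders):

torch::jit::script::Module network = torch::jit::load("./EGAN.pt");
std::string fileName = "pic1.png";
enhanceImage("./input/pic1.png", network, "./output/", fileName);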

I have also checked the tensor outputs at all the steps, and they are the same. However, after the conversion, the C++ output image has color bleeding out of the brighter regions of the input image, as shown below.

[Images: output comparison, Python version (left) vs. C++ version (right)]

I have tried many things, but I am totally puzzled as to how I should solve the problem. Any help is most welcome.

Thanks.
PS: Let me know if more info is required.

So I'm not sure I understand the code in detail, but are you doing the BGR2RGB conversion when decoding? PIL/ToTensor gives RGB, while OpenCV uses BGR by default.

Best regards

Thomas

Hi, no, I do not perform BGR2RGB when decoding. The reason is that I do not use a PIL image in either the Python or the C++ version; I use an OpenCV Mat and perform the tensor operations on it directly.

PS: Should I post more details of the code (such as the includes)?

Best Regards
Apurv

Ah, right, the torchvision transform in your snippet had me confused. How do you check that C++ and Python produce the same output? I would probably stick the C++ code into an extension to check each step.

Hi,
I found the problem in the above implementation. In the C++ version, I was not clamping the values after denormalization. I have added a clamping step, and it is now working as expected.
The edit, in case anyone stumbles on the same problem, is below:

tensor = tensor.add(1).div(2.0).mul(255.0); --> tensor = tensor.add(1).div(2.0).mul(255.0).clamp(0, 255);

Without clamping, values above 255 overflow when the result is cast to kU8, which is what caused the color artifacts in the brighter regions of the image.
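
For reference, here is the conversion function from above with the fix applied:

cv::Mat torchTensortoCVMat(torch::Tensor& tensor)
{
	tensor = tensor.squeeze(0);
	tensor = tensor.to(torch::kCPU).to(torch::kFloat32).detach();
	tensor = tensor.permute({ 1, 2, 0 }).contiguous();
	// Clamp before the uint8 cast so out-of-range values saturate instead of wrapping.
	tensor = tensor.add(1).div(2.0).mul(255.0).clamp(0, 255);
	tensor = tensor.to(torch::kU8);

	int64_t height = tensor.size(0);
	int64_t width  = tensor.size(1);
	cv::Mat mat(cv::Size(width, height), CV_8UC3, tensor.data_ptr<uchar>());
	return mat;
}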

Is there a better way of comparing tensor outputs between Python and C++?
Currently I am checking correctness like a caveman :sweat_smile: by printing the values out.

Best regards
Apurv


Check out the C++ extension tutorial for calling your C++ code from Python; I find that quite convenient.
I do a lot of printing, too, though.
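
An alternative to printing is to dump intermediate tensors from C++ to disk and compare them in Python with torch.allclose. A minimal sketch, assuming torch::pickle_save is available in your libtorch build (I believe it is in recent versions, 1.6 included):

#include <torch/torch.h>
#include <fstream>
#include <string>
#include <vector>

// Serialize a tensor so it can be read back in Python with torch.load(path).
void dumpTensor(const torch::Tensor& tensor, const std::string& path)
{
	std::vector<char> bytes = torch::pickle_save(tensor);
	std::ofstream out(path, std::ios::binary);
	out.write(bytes.data(), bytes.size());
}

// Usage: dumpTensor(tGrayImage, "cpp_attention.pt");
// Python side: torch.allclose(torch.load("cpp_attention.pt"), python_tensor)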
