I’ve tried to convert DPSR to c++ version with torch.jit but result were different. python result looks very better than c++.
this is python code for convert:
import os.path
import glob
import logging
import cv2
import numpy as np
from datetime import datetime
from collections import OrderedDict
from scipy.io import loadmat
import torch
import torchvision
from utils import utils_deblur
from utils import utils_logger
from utils import utils_image as util
from models.network_srresnet import SRResNet
from torchviz import make_dot, make_dot_from_trace
import time
def main():
sf = 4 # from 2, 3 and 4
# noise_level_img = 14./255. # noise level of low-quality image
noise_level_img = 7. / 255. # noise level of low-quality image
testsets = 'testsets'
testset_current = 'real_imgs'
use_srganplus = True # 'True' for SRGAN+ (x4) and 'False' for SRResNet+ (x2,x3,x4)
ims = ['5.jpg'] # frog.png
noise_level_model = noise_level_img # noise level of model
if use_srganplus and sf == 4:
model_prefix = 'DPSRGAN'
save_suffix = 'srganplus'
else:
model_prefix = 'DPSR'
save_suffix = 'srresnet'
model_path = os.path.join('DPSR_models', model_prefix+'x%01d.pth' % (sf))
show_img = True
n_channels = 3 # only color images, fixed
# ================================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SRResNet(in_nc=4, out_nc=3, nc=96, nb=16, upscale=sf, act_mode='R', upsample_mode='pixelshuffle')
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
L_folder = os.path.join(testsets, testset_current, 'LR') # L: Low quality
E_folder = os.path.join(testsets, testset_current, 'x{:01d}_'.format(sf)+save_suffix)
util.mkdir(E_folder)
logger.info(L_folder)
and 'kernel' not in im:
for im in ims:
img_name, ext = os.path.splitext(im)
img = util.imread_uint(os.path.join(L_folder, im), n_channels=n_channels)
h, w = img.shape[:2]
#util.imshow(img, title='Low-resolution image') if show_img else None
img = util.unit2single(img)
img_L = util.single2tensor4(img)
tic = time.clock()
noise_level_map = torch.ones((1, 1, img_L.size(2), img_L.size(3)), dtype=torch.float).mul_(noise_level_model)
img_L = torch.cat((img_L, noise_level_map), dim=1)
img_L = img_L.to(device)
with torch.no_grad():
traced_script_module = torch.jit.trace(model, img_L)
with torch.onnx.set_training(model, False):
trace, _ = torch.jit.get_trace_graph(model, args=(img_L,))
dot = make_dot_from_trace(trace)
dot.format = 'svg'
dot.render('./model')
img_E = model(img_L)
img_E = util.tensor2single(img_E)
toc = time.clock()
print('elapsed time = ' + str(toc - tic))
img_E = util.single2uint(img_E[:h*sf, :w*sf]) # np.uint8((z[:h*sf, :w*sf] * 255.0).round())
traced_script_module.save("model_srrestnetPluse.pt")
if __name__ == '__main__':
main()
and c++ side code:
cv::Mat Deep_plug_play_SR::do_srrestnetpluse(cv::Mat input)
{
double noise_level_img = 7 / 255.0;
_img = input.clone();
// convert to float and normalized in range (0,1)
// same as uint2single
cv::cvtColor(_img, _img, cv::COLOR_BGR2RGB);
_img.convertTo(_img, CV_32F);
_img /= 255;
//
cv::Mat z = _img.clone();
torch::Tensor prediction;
at::TensorOptions opts = at::TensorOptions();
opts.dtype(torch::kFloat32);
torch::Tensor sigmaTensor = torch::from_blob(&noise_level_img, { 1 }, opts);
torch::Tensor img_L = torch::from_blob(z.data, { z.rows, z.cols,
z.channels()}, opts).permute({ 2, 0, 1 }).unsqueeze(0);
torch::Tensor noise_level_map = torch::ones({1, 1, z.rows, z.cols },opts).mul_(sigmaTensor);
torch::Tensor netInput = torch::cat({ img_L, noise_level_map }, 1);
netInput = netInput.to(at::kCUDA);
_inputs.push_back(netInput);
prediction = _net->forward(_inputs).toTensor();
z = Tensor_to_Mat(prediction);
return z;
}
cv::Mat Deep_plug_play_SR::Tensor_to_Mat(torch::Tensor in)
{
in = in.squeeze().clamp(0, 1).permute({ 1,2,0 });
//std::cout << in.sizes() << std::endl;
in = in.mul(255.0).to(torch::kU8);
in = in.to(torch::kCPU);
cv::Mat img_result(in.size(0), in.size(1), CV_8UC3);
std::memcpy((void*)img_result.data, in.data_ptr(), sizeof(torch::kU8)*in.numel());
cv::cvtColor(img_result, img_result, cv::COLOR_RGB2BGR);
return img_result;
}
but python result is :
and c++ is:
I cant understand what is my mistake? or its bug in c++ torch interface or even torch.jit
cant trace network well.