Dear All,
I am trying to run the main_test_real_application.py file from this Git repo: GitHub - cszn/USRNet: Deep Unfolding Network for Image Super-Resolution (CVPR, 2020) (PyTorch). With a low-resolution image as input, the script runs without any problem.
The problem I am encountering is that when I input a relatively big image and set scale factor to 4, my GPU(0) runs out of memory.
I have two Tesla K80 GPUs, and I have tried to use DataParallel in order to make use of all the memory available on my server, but I couldn’t get it to work.
My question is: could you show me how to properly implement DataParallel in this script?
Thanks a lot
Alex
import os.path
import cv2
import logging
import numpy as np
from datetime import datetime
from collections import OrderedDict
from scipy.io import loadmat
from scipy import ndimage
import scipy.io as scio
import torch
from utils import utils_deblur
from utils import utils_logger
from utils import utils_sisr as sr
from utils import utils_image as util
from models.network_usrnet import USRNet as net
def main():
    """Super-resolve a single real-world image with a pretrained USRNet model.

    All settings (model name, test image, scale factor, noise level, output
    flags) are hard-coded in the "Preparation" section below.  The function
    builds a shifted Gaussian blur kernel for the chosen scale factor, loads
    the checkpoint from ``model_zoo``, pads the low-resolution input, runs
    one forward pass and writes the result under ``results/``.

    Raises:
        FileNotFoundError: if the configured test image does not exist.
    """
    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    model_name = 'usrgan'      # 'usrgan' | 'usrnet' | 'usrgan_tiny' | 'usrnet_tiny'
    testset_name = 'set_real'  # test set, 'set_real'
    test_image = 'image1.jpg'  # 'chip.png', 'comic.png'
    sf = 4                     # scale factor, only from {1, 2, 3, 4}
    show_img = False           # default: False
    save_E = True              # save estimated image
    save_LE = False            # save zoomed LR, Estimated images
    noise_level_img = 2        # noise level for LR image, 0.5~3 for clean images
    # default Gaussian kernel widths of clean/sharp images for x1, x2, x3, x4
    kernel_width_default_x1234 = [0.4, 0.7, 1.5, 2.0]

    noise_level_model = noise_level_img / 255.  # noise level of model
    kernel_width = kernel_width_default_x1234[sf - 1]

    k = utils_deblur.fspecial('gaussian', 25, kernel_width)
    k = sr.shift_pixel(k, sf)  # shift the kernel
    k /= np.sum(k)             # normalise so the kernel sums to 1
    if show_img:
        util.surf(k)
    kernel = util.single2tensor4(k[..., np.newaxis])

    n_channels = 1 if 'gray' in model_name else 3  # 3 for color image, 1 for grayscale image
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'     # fixed
    results = 'results'       # fixed
    result_name = testset_name + '_' + model_name
    model_path = os.path.join(model_pool, model_name + '.pth')

    # ----------------------------------------
    # L_path, E_path
    # ----------------------------------------
    L_path = os.path.join(testsets, testset_name)  # L_path, fixed, for Low-quality images
    E_path = os.path.join(results, result_name)    # E_path, fixed, for Estimated images
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name + '.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    if 'tiny' in model_name:
        model = net(n_iter=6, h_nc=32, in_nc=4, out_nc=3, nc=[16, 32, 64, 64],
                    nb=2, act_mode="R", downsample_mode='strideconv', upsample_mode="convtranspose")
    else:
        model = net(n_iter=8, h_nc=64, in_nc=4, out_nc=3, nc=[64, 128, 256, 512],
                    nb=2, act_mode="R", downsample_mode='strideconv', upsample_mode="convtranspose")

    if torch.cuda.device_count() > 1:
        # NOTE: informational only -- the model is NOT wrapped in
        # torch.nn.DataParallel here.  DataParallel scatters its inputs along
        # the batch dimension, so with a single image (batch size 1) it would
        # not spread the memory of one forward pass over several GPUs.
        print("Let's use", torch.cuda.device_count(), "GPUs!")

    # Load the weights onto the CPU first so a checkpoint saved from a CUDA
    # device can also be read on a CPU-only machine; the model is moved to
    # `device` below.
    model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
    model.eval()  # inference mode for any train/eval-dependent layers
    for param in model.parameters():
        param.requires_grad = False

    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))
    logger.info('model_name:{}, image sigma:{}'.format(model_name, noise_level_img))
    logger.info(L_path)

    img_path = os.path.join(L_path, test_image)
    if not os.path.isfile(img_path):
        raise FileNotFoundError('Test image not found: {}'.format(img_path))

    # ------------------------------------
    # (1) img_L
    # ------------------------------------
    img_name, ext = os.path.splitext(os.path.basename(img_path))
    img_L = util.imread_uint(img_path, n_channels=n_channels)
    img_L = util.uint2single(img_L)

    # NOTE: shape[:2] is (rows, cols), so `w` holds the row count and `h` the
    # column count; every use below is consistent with that convention.
    w, h = img_L.shape[:2]
    logger.info('{:>10s}--> ({:>4d}x{:<4d})'.format(img_name + ext, w, h))

    # boundary handling: upscale, pad to a multiple of `boarder`, downsample
    # again, then paste the untouched LR image into the top-left corner.
    boarder = 8  # default setting for kernel size 25x25
    img_up = cv2.resize(img_L, (sf * h, sf * w), interpolation=cv2.INTER_NEAREST)
    padded_size = [int(np.ceil(sf * w / boarder + 2) * boarder),
                   int(np.ceil(sf * h / boarder + 2) * boarder)]
    img_up = utils_deblur.wrap_boundary_liu(img_up, padded_size)
    img_wrap = sr.downsample_np(img_up, sf, center=False)
    img_wrap[:w, :h, :] = img_L
    img_L = util.single2tensor4(img_wrap)

    # ------------------------------------
    # (2) img_E
    # ------------------------------------
    sigma = torch.tensor(noise_level_model).float().view([1, 1, 1, 1])
    img_L, kernel, sigma = (el.to(device) for el in (img_L, kernel, sigma))

    # Inference only: make sure autograd records nothing for this pass.
    with torch.no_grad():
        img_E = model(img_L, kernel, sf, sigma)

    # Crop away the boundary padding added in step (1).
    img_E = util.tensor2uint(img_E)[:sf * w, :sf * h, ...]

    if save_E:
        util.imsave(img_E, os.path.join(E_path, img_name + '_x' + str(sf) + '_' + model_name + '.png'))

    # --------------------------------
    # (3) save img_LE
    # --------------------------------
    if save_LE:
        # Render the kernel as a small image overlaid on the zoomed LR image.
        k_v = k / np.max(k) * 1.2
        k_v = util.single2uint(np.tile(k_v[..., np.newaxis], [1, 1, 3]))
        k_factor = 3
        k_v = cv2.resize(k_v, (k_factor * k_v.shape[1], k_factor * k_v.shape[0]),
                         interpolation=cv2.INTER_NEAREST)
        img_L = util.tensor2uint(img_L)[:w, :h, ...]
        img_I = cv2.resize(img_L, (sf * img_L.shape[1], sf * img_L.shape[0]),
                           interpolation=cv2.INTER_NEAREST)
        img_I[:k_v.shape[0], :k_v.shape[1], :] = k_v
        img_LE = np.concatenate([img_I, img_E], axis=1)
        if show_img:
            util.imshow(img_LE, title='LR / Recovered')
        util.imsave(img_LE, os.path.join(E_path, img_name + '_x' + str(sf) + '_' + model_name + '_LE.png'))
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()