Can't figure out how to integrate DataParallel into this workflow

Dear All,

I am trying to run the main_test_real_application.py file from this GitHub repo: GitHub - cszn/USRNet: Deep Unfolding Network for Image Super-Resolution (CVPR, 2020) (PyTorch). With a low-resolution image as input, the script runs without any problem.
The problem I am encountering is that when I input a relatively big image and set the scale factor to 4, my GPU 0 runs out of memory.

I have two Tesla K80s and I have tried to use DataParallel in order to make use of all the GPU memory available on my server, but I couldn't manage to get it to work.

My question is: could you show me how to properly implement DataParallel in this script?
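
For reference, this is roughly the change I tried around the device-count check, following the DataParallel tutorial (a sketch from memory, so I may well be placing the wrap wrong):

model.load_state_dict(torch.load(model_path), strict=True)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = torch.nn.DataParallel(model)  # wrap after loading the weights so the state_dict keys still match
model = model.to(device)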

Thanks a lot

Alex

import os.path
import cv2
import logging

import numpy as np
from datetime import datetime
from collections import OrderedDict
from scipy.io import loadmat
from scipy import ndimage
import scipy.io as scio

import torch

from utils import utils_deblur
from utils import utils_logger
from utils import utils_sisr as sr
from utils import utils_image as util

from models.network_usrnet import USRNet as net



def main():

    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    model_name = 'usrgan'      # 'usrgan' | 'usrnet' | 'usrgan_tiny' | 'usrnet_tiny'
    testset_name = 'set_real'  # test set,  'set_real'
    test_image = 'image1.jpg'    # 'chip.png', 'comic.png'
    #test_image = 'comic.png'

    sf = 4                  # scale factor, only from {1, 2, 3, 4}
    show_img = False           # default: False
    save_E = True              # save estimated image
    save_LE = False             # save zoomed LR, Estimated images

    
    
    noise_level_img = 2       # noise level for LR image, 0.5~3 for clean images
    kernel_width_default_x1234 = [0.4, 0.7, 1.5, 2.0] # default Gaussian kernel widths of clean/sharp images for x1, x2, x3, x4

    noise_level_model = noise_level_img/255.  # noise level of model
    kernel_width = kernel_width_default_x1234[sf-1]

    k = utils_deblur.fspecial('gaussian', 25, kernel_width)
    k = sr.shift_pixel(k, sf)  # shift the kernel
    k /= np.sum(k)
    util.surf(k) if show_img else None
    
    kernel = util.single2tensor4(k[..., np.newaxis])


    n_channels = 1 if 'gray' in  model_name else 3  # 3 for color image, 1 for grayscale image
    model_pool = 'model_zoo'  # fixed
    testsets = 'testsets'     # fixed
    results = 'results'       # fixed
    result_name = testset_name + '_' + model_name
    model_path = os.path.join(model_pool, model_name+'.pth')

    # ----------------------------------------
    # L_path, E_path
    # ----------------------------------------
    L_path = os.path.join(testsets, testset_name) # L_path, fixed, for Low-quality images
    E_path = os.path.join(results, result_name)   # E_path, fixed, for Estimated images
    util.mkdir(E_path)

    logger_name = result_name
    utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
    logger = logging.getLogger(logger_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ----------------------------------------
    # load model
    # ----------------------------------------
    if 'tiny' in model_name:
        model = net(n_iter=6, h_nc=32, in_nc=4, out_nc=3, nc=[16, 32, 64, 64],
                    nb=2, act_mode="R", downsample_mode='strideconv', upsample_mode="convtranspose")
    else:
        model = net(n_iter=8, h_nc=64, in_nc=4, out_nc=3, nc=[64, 128, 256, 512],
                    nb=2, act_mode="R", downsample_mode='strideconv', upsample_mode="convtranspose")
    

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")


    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    
   
    for key, v in model.named_parameters():
        v.requires_grad = False

    number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
    logger.info('Params number: {}'.format(number_parameters))
    model = model.to(device)
    logger.info('Model path: {:s}'.format(model_path))

    logger.info('model_name:{}, image sigma:{}'.format(model_name, noise_level_img))
    logger.info(L_path)

    img = os.path.join(L_path, test_image)
    # ------------------------------------
    # (1) img_L
    # ------------------------------------
    img_name, ext = os.path.splitext(os.path.basename(img))
    img_L = util.imread_uint(img, n_channels=n_channels)
    img_L = util.uint2single(img_L)

    
    w, h = img_L.shape[:2]
    logger.info('{:>10s}--> ({:>4d}x{:<4d})'.format(img_name+ext, w, h))

    # boundary handling
    boarder = 8     # default setting for kernel size 25x25
    img = cv2.resize(img_L, (sf*h, sf*w), interpolation=cv2.INTER_NEAREST)
    img = utils_deblur.wrap_boundary_liu(img, [int(np.ceil(sf*w/boarder+2)*boarder), int(np.ceil(sf*h/boarder+2)*boarder)])
    img_wrap = sr.downsample_np(img, sf, center=False)
    img_wrap[:w, :h, :] = img_L
    img_L = img_wrap

    img_L = util.single2tensor4(img_L)
    
    img_L = img_L.to(device)
    
    # ------------------------------------
    # (2) img_E
    # ------------------------------------
    sigma = torch.tensor(noise_level_model).float().view([1, 1, 1, 1])
    
    [img_L, kernel, sigma] = [el.to(device) for el in [img_L, kernel, sigma]]
    
    img_E = model(img_L, kernel, sf, sigma)
    
    img_E = util.tensor2uint(img_E)[:sf*w, :sf*h, ...]

    if save_E:
        util.imsave(img_E, os.path.join(E_path, img_name+'_x'+str(sf)+'_'+model_name+'.png'))

    # --------------------------------
    # (3) save img_LE
    # --------------------------------
    if save_LE:
        k_v = k/np.max(k)*1.2
        k_v = util.single2uint(np.tile(k_v[..., np.newaxis], [1, 1, 3]))
        k_factor = 3
        k_v = cv2.resize(k_v, (k_factor*k_v.shape[1], k_factor*k_v.shape[0]), interpolation=cv2.INTER_NEAREST)
        img_L = util.tensor2uint(img_L)[:w, :h, ...]
        img_I = cv2.resize(img_L, (sf*img_L.shape[1], sf*img_L.shape[0]), interpolation=cv2.INTER_NEAREST)
        img_I[:k_v.shape[0], :k_v.shape[1], :] = k_v
        util.imshow(np.concatenate([img_I, img_E], axis=1), title='LR / Recovered') if show_img else None
        util.imsave(np.concatenate([img_I, img_E], axis=1), os.path.join(E_path, img_name+'_x'+str(sf)+'_'+model_name+'_LE.png'))


if __name__ == '__main__':

    main()

If a single sample yields an OOM error on a single device, nn.DataParallel won’t save you from this.
As the name suggests, a data parallel approach is used: the input batch is split along dim0 and each chunk is sent to one of the specified devices (see the small sketch below).
If you want to use model sharding (model parallel), you could take a look at this post.
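
For illustration, a small sketch (with a toy module, not USRNet) of how the scatter along dim0 behaves:

import torch
import torch.nn as nn

# the wrapped module is replicated onto all visible GPUs during forward
model = nn.DataParallel(nn.Conv2d(3, 8, 3, padding=1)).to('cuda:0')

x = torch.randn(4, 3, 64, 64, device='cuda:0')
out = model(x)   # batch of 4 is chunked along dim0, one chunk per GPU

x = torch.randn(1, 3, 64, 64, device='cuda:0')
out = model(x)   # batch of 1 cannot be chunked, so everything still runs on cuda:0

So with your single input image (batch size 1) the memory usage on GPU 0 won't change.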

Hi ptrblck,

Thanks for your answer, I was actually hoping for an answer from you :-).

Now I understand why what I have tried does not work.

I've had a look at the post you have kindly shared; would model parallel solve the issue?
If I understand correctly, I have to split my input data (the image) in two or four, send the split data to the different GPUs, and merge the results once processing completes. Is that correct?

Could you show me how you would use model parallel in this specific case? My input is a JPG image.

Thanks a lot

No, that would be the data parallel approach.
For model sharding, you would keep only part of the model on each device and transfer the forward activations to the corresponding device.

My example code snippet in that post shows how the model should be manipulated for (manual) model sharding; the toy sketch below illustrates the same idea.
However, I think higher-level APIs such as Ignite, Lightning, Catalyst, etc. would already have an (automatic) utility for it, so you could also look at how they've implemented it.
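
In case it helps, here is a minimal manual sharding sketch with a toy two-stage module (assuming two visible GPUs; for USRNet you would have to pick the split points inside the network yourself):

import torch
import torch.nn as nn

class ShardedNet(nn.Module):
    def __init__(self):
        super().__init__()
        # first part of the network lives on the first GPU
        self.part1 = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU()).to('cuda:0')
        # second part lives on the second GPU
        self.part2 = nn.Conv2d(64, 3, 3, padding=1).to('cuda:1')

    def forward(self, x):
        x = self.part1(x.to('cuda:0'))
        x = x.to('cuda:1')  # transfer the intermediate activation to the second device
        return self.part2(x)

model = ShardedNet()
out = model(torch.randn(1, 3, 256, 256))  # the output tensor ends up on cuda:1

Each device then only has to hold its part of the parameters and activations, at the cost of the GPUs working sequentially on a single input.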