I am using two NVIDIA Quadro 1200 (4 GB) GPUs to run inference on images of size 1024×1792 with a U-Net segmentation model, using PyTorch's DataParallel. Although inference starts, all of the work goes to only one GPU while the other stays idle. Because everything runs on a single GPU, I get a CUDA out-of-memory (OOM) error every time. I can't reduce the image size, because that would compromise our objective. The code is below.
import torch
import cv2
import numpy as np
from model import build_unet
from torch.nn.parallel import DataParallel
import os
from tqdm import tqdm
# Location of the trained U-Net weights, relative to the current working directory.
checkpoint_path = "Weights/best_model.pth"
# Ask the CUDA caching allocator to release unused cached blocks. Nothing visible in
# this script has allocated GPU memory yet at this point, so this has little effect here.
torch.cuda.empty_cache()
def mask_parse(mask):
    """Replicate a single-channel mask into three identical channels.

    A new trailing axis of length 3 is added and every channel is a copy of
    ``mask``; the dtype is preserved. For example, an ``(H, W)`` mask becomes
    ``(H, W, 3)``, so it can be blended with a 3-channel image.
    """
    # The pasted `\[mask, mask, mask\]` was a Markdown-escaping artifact and a
    # SyntaxError (backslash line continuation followed by a character).
    mask = np.expand_dims(mask, axis=-1)
    return np.concatenate([mask, mask, mask], axis=-1)
def inference_with_dataparallel(model, cv_img, imname, *, threshold=0.1,
                                out_size=None, out_dir="Output", use_amp=False):
    """Segment one image with ``model`` and write a blended mask overlay to disk.

    Note: ``torch.nn.DataParallel`` splits its input along dim 0 (the batch
    dimension). This function always feeds a batch of one image, so only a
    single replica/GPU does the work. Using a second GPU for one image needs
    model parallelism or tiling, not DataParallel.

    Args:
        model: Segmentation network (the caller wraps it in ``DataParallel``).
            Its output is presumably ``(1, 1, H, W)`` logits; the channel axis
            is squeezed, so a multi-channel output raises. TODO confirm.
        cv_img: Image array as produced by ``cv2.imread`` + ``cv2.resize`` in
            the caller, presumably ``(H, W, 3)`` uint8 BGR. TODO confirm.
        imname: File name for the overlay inside ``out_dir``.
        threshold: Cut-off applied to the sigmoid probabilities (default 0.1,
            as before; 0.15 was also tried).
        out_size: ``(width, height)`` of the written overlay. ``None`` falls
            back to the module-level ``mWidth``/``mHeight`` set by the
            ``__main__`` loop (the original behavior).
        out_dir: Output directory. It is created if it does not exist.
        use_amp: If true, run the forward pass under float16 autocast, which
            roughly halves activation memory. This is opt-in because
            probabilities can differ slightly from float32.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the overlay was not written.
    """
    import contextlib  # local: only needed for the optional autocast context

    # Take the input device from the model instead of relying on a global
    # `device`. For a DataParallel model this is the first device it uses.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    if out_size is None:
        out_size = (mWidth, mHeight)

    # HWC -> 1xCxHxW float32 scaled to [0, 1].
    x = np.transpose(cv_img, (2, 0, 1)) / 255.0
    x = np.expand_dims(x, axis=0).astype(np.float32)
    x = torch.from_numpy(x).to(device)

    amp_ctx = (torch.autocast(device_type=device.type, dtype=torch.float16)
               if use_amp else contextlib.nullcontext())
    with torch.no_grad(), amp_ctx:
        logits = model(x)
    del x  # drop the GPU input before post-processing

    # .float() undoes a float16 autocast output (no-op for float32 logits).
    prob = torch.sigmoid(logits.float())[0].cpu().numpy()
    del logits  # release the GPU output as soon as it is on the host
    prob = np.squeeze(prob, axis=0)
    mask = (prob > threshold).astype(np.uint8)
    overlay = mask_parse(mask) * 255  # 0/1 -> 0/255, stays uint8

    image_resized = cv2.resize(cv_img, out_size)
    overlay = cv2.resize(overlay, out_size)
    blended = cv2.addWeighted(image_resized, 0.6, overlay, 0.4, 0)

    # cv2.imwrite returns False (without raising) when the folder is missing.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, imname)
    if not cv2.imwrite(out_path, blended):
        raise OSError(f"cv2.imwrite could not write {out_path!r}")
if __name__ == "__main__":
    # Why the second GPU stays idle: DataParallel scatters each input along
    # dim 0 (the batch). Images are fed one at a time here, so only one
    # replica is created. All compute and activation memory therefore land on
    # model.device_ids[0]. Splitting one image across GPUs needs model
    # parallelism (different layers on different devices) or tiled inference.
    model = build_unet()
    device = torch.device("cuda")
    model = model.to(device)
    model = DataParallel(model)
    # The weights are loaded after wrapping, so the checkpoint keys are
    # expected to carry DataParallel's "module." prefix. Presumably the file
    # was saved from a DataParallel model; TODO confirm.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cuda"))
    # Switch to evaluation mode so any BatchNorm/Dropout layers use their
    # trained behavior instead of training-mode behavior during inference.
    model.eval()

    input_dir = "TestImages"
    # The inference function writes to Output/; cv2.imwrite fails without
    # raising if this folder does not exist.
    os.makedirs("Output", exist_ok=True)
    for allCSV in tqdm(sorted(os.listdir(input_dir))):
        imgName = os.path.join(input_dir, allCSV)
        cv_img = cv2.imread(imgName)
        if cv_img is None:  # not an image, or unreadable: skip instead of crashing on .shape
            tqdm.write(f"Skipping unreadable file: {imgName}")
            continue
        # mHeight/mWidth are module-level globals read by inference_with_dataparallel.
        mHeight, mWidth = cv_img.shape[0], cv_img.shape[1]
        # cv2.resize takes dsize as (width, height): result is 1792 high x 1024 wide.
        cv_img = cv2.resize(cv_img, (1024, 1792))
        inference_with_dataparallel(model, cv_img, allCSV)

    # `device_ids` was never defined (NameError). Report every device that
    # DataParallel manages instead; this also works with a single GPU.
    for idx in model.device_ids:
        print(f"Gpu{idx}:", torch.cuda.max_memory_allocated(device=idx))