Well, it doesn't work yet.
Here is the full code (it is pretty amateurish):
import argparse
import logging
import os
import sys
import cv2
import SSIM.pytorch_ssim as ssim
import torchvision
# import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from numpy import linalg as LA
from eval import eval_net
from unet import UNet, DeepLabv3_plus, deeplab_v3_separation
# import coco_annotation
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
dir_img = 'data/train2017/train2017_640480/'
dir_mask = 'data/train2017/train2017_640480_anns/'
dir_checkpoint = 'checkpoints/'
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
def train_net(net,
              device,
              epochs=5,
              batch_size=2,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.25):
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0
    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.1)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)  # original: 'min' if net.n_classes > 1 else 'max'
    # if net.n_classes > 1:
    #     criterion = nn.CrossEntropyLoss()
    # else:
    #     criterion = nn.CrossEntropyLoss()
    #     # criterion = nn.BCEWithLogitsLoss()
    L1_criterion = nn.L1Loss()
    semantic_criterion = nn.CrossEntropyLoss(reduction='none')
    for epoch in range(epochs):
        net.train()
        # epoch_semantic_loss = 0
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                # print('Test Test Test')
                mixedImgs = batch['mixedImage']
                if mixedImgs.size()[0] == 1:
                    continue
                BackgroundImage = batch['BackgroundImage']
                ReflectionImage = batch['ReflectionImage']
                BackgtoundSemantics = batch['BackgtoundSemantics']
                assert mixedImgs.shape[1] == 3, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {mixedImgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'
                mixedImgs = mixedImgs.to(device=device)  # , dtype=torch.float32)
                BackgroundImage = BackgroundImage.to(device=device)  # , dtype=torch.float32)*255
                ReflectionImage = ReflectionImage.to(device=device)  # , dtype=torch.float32)
                BackgtoundSemantics = BackgtoundSemantics.to(device=device, dtype=torch.long)
                semantic_pred, BG_img_pred, R_img_pred = net(mixedImgs)
                # variance of the reflection reconstruction error, used below to scale the image losses
                varR = torch.var(torch.abs(R_img_pred - ReflectionImage))  # dim=0, keepdim=True,
                # per-pixel cross entropy of the semantic branch (reduction='none')
                # semLos = -BackgtoundSemantics*torch.log10(semantic_pred)-(1-BackgtoundSemantics)*torch.log10(1-semantic_pred)
                semLos = semantic_criterion(torch.abs(semantic_pred), BackgtoundSemantics[:, 0, :, :])
                varS = torch.var(semLos)
                # background reconstruction: L1 + SSIM, plus L1 on the reflection branch
                loss = L1_criterion(BG_img_pred, BackgroundImage)
                # loss += 0.0003*LA.norm(cv2.Canny(BG_img_pred,100,200) - cv2.Canny(BackgroundImage,100,200))
                loss += 0.6 * (1 - ssim.ssim(BG_img_pred, BackgroundImage))
                loss += 0.8 * L1_criterion(R_img_pred, ReflectionImage)
                # scale each task term by the squared variance of its own error
                loss = loss / (2 * varR.square())
                loss += semantic_criterion(semantic_pred, BackgtoundSemantics[:, 0, :, :]) / (2 * varS.square())
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})
                optimizer.zero_grad()
                loss.backward()
                # semantic_loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()
                pbar.update(mixedImgs.shape[0])
                global_step += 1
                if global_step % (len(dataset) // (10 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                    val_score = eval_net(net, val_loader, device)
                    # scheduler.step(val_score)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
                    if 1 > 1:  # net.n_classes
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)
                    writer.add_images('images', mixedImgs, global_step)
                    # if net.n_classes == 1:
                    #     writer.add_images('masks/true', BG_img_pred, global_step)
                    #     writer.add_images('masks/pred', torch.sigmoid(imgs_pred) > 0.5, global_step)
        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')
    writer.close()
def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=2,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
                        help='Learning rate', dest='lr')
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.25,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    return parser.parse_args()
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    # args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = deeplab_v3_separation()  # UNet(n_channels=3, n_classes=3, bilinear=True)
    # logging.info(f'Network:\n'
    #              f'\t{net.n_channels} input channels\n'
    #              f'\t{net.n_classes} output channels (classes)\n'
    #              f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
    if args.load:
        net.load_state_dict(
            torch.load(args.load, map_location=device)
        )
        logging.info(f'Model loaded from {args.load}')
    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True
    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
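For reference, here is a rough, isolated smoke test of the main loss terms on random tensors shaped like my batches (just a sketch; it reuses the same local SSIM module the script imports, and assumes 92 semantic classes, which is what the model below outputs):

import torch
import torch.nn as nn
import SSIM.pytorch_ssim as ssim  # same local SSIM module imported by the training script

B, H, W = 2, 64, 64
BG_img_pred = torch.rand(B, 3, H, W)
BackgroundImage = torch.rand(B, 3, H, W)
R_img_pred = torch.rand(B, 3, H, W)
ReflectionImage = torch.rand(B, 3, H, W)
semantic_pred = torch.randn(B, 92, H, W)                  # logits from the semantic branch
BackgtoundSemantics = torch.randint(0, 92, (B, 1, H, W))  # integer class labels

L1_criterion = nn.L1Loss()
semantic_criterion = nn.CrossEntropyLoss(reduction='none')

loss = L1_criterion(BG_img_pred, BackgroundImage)
loss += 0.6 * (1 - ssim.ssim(BG_img_pred, BackgroundImage))
loss += 0.8 * L1_criterion(R_img_pred, ReflectionImage)
semLos = semantic_criterion(semantic_pred, BackgtoundSemantics[:, 0, :, :])  # per-pixel map
print(loss.item(), semLos.shape)  # scalar image loss, (B, H, W) semantic loss map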
This is the model:
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 28 14:09:04 2020
@author: aviel
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torchvision.models as models
class deeplab_v3_separation(nn.Module):
    def __init__(self):
        super(deeplab_v3_separation, self).__init__()
        deepLabV3ResNet101 = models.segmentation.deeplabv3_resnet101(pretrained=True)
        features_convs_and_ASSP = list(deepLabV3ResNet101.children())
        ResNet101 = models.resnet101(pretrained=True)
        ResNet101_features = list(ResNet101.children())[:-2]
        ASSP_layers = models.segmentation.deeplabv3_resnet101(pretrained=True)
        ASSP_features = list(ASSP_layers.children())
        self.ResNet101_features = nn.Sequential(*ResNet101_features)
        self.ASSP_features = nn.Sequential(*ASSP_features[1])
        self.Semantic_convs = nn.Conv2d(2069, 92, 1)
        self.semanticSoftmax = nn.Softmax2d()
        self.Background_convs = nn.Conv2d(2164, 3, 1)
        self.Separation_convs = nn.Conv2d(2167, 3, 1)
        # self.up = nn.ConvTranspose2d(2069, 2069, kernel_size=2, stride=2)

    def forward(self, x):
        input_shape = x.shape[-2:]
        x1 = self.ResNet101_features(x)
        x2 = self.ASSP_features(x1)
        x3 = torch.cat((x1, x2), 1)
        x3 = F.interpolate(x3, size=input_shape, mode='bilinear', align_corners=False)
        y1 = self.Semantic_convs(x3)
        # y1 = self.semanticSoftmax(y1)
        y2 = self.Background_convs(torch.cat((x, x3, y1), 1))
        y3 = self.Separation_convs(torch.cat((x, y1, y2, x3), 1))
        return y1, y2, y3
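As a rough sanity check of the channel math and output shapes (a sketch; it assumes the class is importable from unet, which is where the training script gets it, and constructing it downloads the pretrained DeepLabV3/ResNet101 weights):

import torch
from unet import deeplab_v3_separation  # same module the training script imports from

# The 1x1 conv input widths come from: 2048 (ResNet101 features) + 21 (DeepLabHead output) = 2069,
# 3 + 2069 + 92 = 2164, and 3 + 92 + 3 + 2069 = 2167.
net = deeplab_v3_separation().eval()
x = torch.randn(1, 3, 256, 256)   # dummy RGB batch, spatial size divisible by 32
with torch.no_grad():
    y1, y2, y3 = net(x)
print(y1.shape)  # expected: torch.Size([1, 92, 256, 256]) - semantic logits
print(y2.shape)  # expected: torch.Size([1, 3, 256, 256])  - background estimate
print(y3.shape)  # expected: torch.Size([1, 3, 256, 256])  - reflection estimate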
And this is the dataset used by the train loader:
import os
from os.path import splitext
from os import listdir
import numpy as np
from glob import glob
import torch
from torch.utils.data import Dataset
import logging
from PIL import Image
class BasicDataset(Dataset):
    def __init__(self, imgs_dir, masks_dir, scale=1):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
        self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
                    if not file.startswith('.')]
        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img, scale):
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small'
        pil_img = pil_img.resize((newW, newH))
        img_nd = np.array(pil_img)
        if len(img_nd.shape) == 2:
            img_nd = np.expand_dims(img_nd, axis=2)
        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255
        return img_trans
    @classmethod
    def semantic_preprocess(cls, pil_img, scale):
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small'
        pil_img = pil_img.resize((newW, newH))
        img_nd = np.array(pil_img)
        if len(img_nd.shape) == 2:
            img_nd = np.expand_dims(img_nd, axis=2)
        # HWC to CHW; label values are kept as-is, not normalized to [0, 1]
        img_trans = img_nd.transpose((2, 0, 1))
        # if img_trans.max() > 1:
        #     img_trans = img_trans / 255
        return img_trans
    def __getitem__(self, i):
        idx = self.ids[i]
        idx2 = self.ids[int(len(self.ids) * torch.rand(1))]
        img_file = glob(self.imgs_dir + idx + '.*')
        img_file2 = glob(self.imgs_dir + idx2 + '.*')
        assert len(img_file) == 1, \
            f'Either no mask or multiple masks found for the ID {idx}: {img_file}'
        assert len(img_file2) == 1, \
            f'Either no image or multiple images found for the ID {idx}: {img_file2}'
        img = Image.open(img_file[0])
        img2 = Image.open(img_file2[0])
        # skip grayscale images for the background
        while img.mode == 'L':
            i += 1
            idx = self.ids[i]
            img_file = glob(self.imgs_dir + idx + '.*')
            assert len(img_file) == 1, \
                f'Either no mask or multiple masks found for the ID {idx}: {img_file}'
            img = Image.open(img_file[0])
        # skip grayscale images for the reflection
        while img2.mode == 'L':
            idx2 = self.ids[int(len(self.ids) * torch.rand(1))]
            img_file2 = glob(self.imgs_dir + idx2 + '.*')
            assert len(img_file2) == 1, \
                f'Either no image or multiple images found for the ID {idx}: {img_file2}'
            img2 = Image.open(img_file2[0])
        assert img.size == img2.size, \
            f'Image and mask {idx} should be the same size, but are {img.size} and {img2.size}'
        img = self.preprocess(img, self.scale)
        img2 = self.preprocess(img2, self.scale)
        mask_file = glob(self.masks_dir + idx + '.*')
        BackgtoundSemantics = Image.open(mask_file[0])
        BackgtoundSemantics = self.semantic_preprocess(BackgtoundSemantics, self.scale)
        # BackgtoundSemantics = np.array(BackgtoundSemantics)
        if BackgtoundSemantics.max() > 1:
            BackgtoundSemantics = BackgtoundSemantics - 92
        # img = np.array(img)/255
        # img2 = np.array(img2)/255
        # outPutImages = (np.append(img,img2,axis=0))
        # blend the two images with a random background/reflection coefficient
        BackgroundCoeff = 0.5 + 0.1 * np.random.randint(5, size=1)
        ReflectionCoeff = 1 - BackgroundCoeff
        mixedImage = np.array(BackgroundCoeff * img + ReflectionCoeff * img2)
        return {
            'mixedImage': torch.from_numpy(mixedImage).type(torch.float32),
            'BackgroundImage': torch.from_numpy(img).type(torch.float32),
            'ReflectionImage': torch.from_numpy(img2).type(torch.float32),
            'BackgtoundSemantics': torch.from_numpy(BackgtoundSemantics).type(torch.float32)
        }
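A quick way to inspect a single sample from this dataset (just a sketch; the directory paths are the ones hard-coded at the top of the training script, so it assumes that data exists):

from utils.dataset import BasicDataset

dataset = BasicDataset('data/train2017/train2017_640480/',
                       'data/train2017/train2017_640480_anns/',
                       scale=0.25)
sample = dataset[0]
for name, tensor in sample.items():
    # print shape, dtype and value range of each returned tensor
    print(name, tuple(tensor.shape), tensor.dtype, tensor.min().item(), tensor.max().item())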