import torch
from torch import nn
from torch import autograd
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from models.discriminator import Discriminator_STE
from PIL import Image
import numpy as np
def gram_matrix(feat):
# https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
(b, ch, h, w) = feat.size()
feat = feat.view(b, ch, h * w)
feat_t = feat.transpose(1, 2)
gram = torch.bmm(feat, feat_t) / (ch * h * w)
return gram
def visual(image):
im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy()
Image.fromarray(im[0].astype(np.uint8)).show()
def dice_loss(input, target):
input = torch.sigmoid(input)
input = input.contiguous().view(input.size()[0], -1)
target = target.contiguous().view(target.size()[0], -1)
input = input
target = target
a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + 0.001
c = torch.sum(target * target, 1) + 0.001
d = (2 * a) / (b + c)
dice_loss = torch.mean(d)
return 1 - dice_loss
class LossWithGAN_STE(nn.Module):
def init(self, logPath, extractor, Lamda, lr, betasInit=(0.5, 0.9)):
super(LossWithGAN_STE, self).init()
self.l1 = nn.L1Loss()
self.extractor = extractor
self.discriminator = Discriminator_STE(3) ## local_global sn patch gan
self.D_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betasInit)
self.cudaAvailable = torch.cuda.is_available()
self.numOfGPUs = torch.cuda.device_count()
self.lamda = Lamda
self.writer = SummaryWriter(logPath)
def forward(self, input, mask, x_o1,x_o2,x_o3,output,mm, gt, count, epoch):
self.discriminator.zero_grad()
D_real = self.discriminator(gt, mask)
D_real = D_real.mean().sum() * -1
D_fake = self.discriminator(output, mask)
D_fake = D_fake.mean().sum() * 1
D_loss = torch.mean(F.relu(1.+D_real)) + torch.mean(F.relu(1.+D_fake)) #SN-patch-GAN loss
D_fake = -torch.mean(D_fake) # SN-Patch-GAN loss
self.D_optimizer.zero_grad()
D_loss.backward(retain_graph=True)
self.D_optimizer.step()
self.writer.add_scalar('LossD/Discrinimator loss', D_loss.item(), count)
output_comp = mask * input + (1 - mask) * output
# import pdb;pdb.set_trace()
holeLoss = 10 * self.l1((1 - mask) * output, (1 - mask) * gt)
validAreaLoss = 2*self.l1(mask * output, mask * gt)
mask_loss = dice_loss(mm, 1-mask)
### MSR loss ###
masks_a = F.interpolate(mask, scale_factor=0.25)
masks_b = F.interpolate(mask, scale_factor=0.5)
imgs1 = F.interpolate(gt, scale_factor=0.25)
imgs2 = F.interpolate(gt, scale_factor=0.5)
msrloss = 8 * self.l1((1-mask)*x_o3,(1-mask)*gt) + 0.8*self.l1(mask*x_o3, mask*gt)+\
6 * self.l1((1-masks_b)*x_o2,(1-masks_b)*imgs2)+1*self.l1(masks_b*x_o2,masks_b*imgs2)+\
5 * self.l1((1-masks_a)*x_o1,(1-masks_a)*imgs1)+0.8*self.l1(masks_a*x_o1,masks_a*imgs1)
feat_output_comp = self.extractor(output_comp)
feat_output = self.extractor(output)
feat_gt = self.extractor(gt)
prcLoss = 0.0
for i in range(3):
prcLoss += 0.01 * self.l1(feat_output[i], feat_gt[i])
prcLoss += 0.01 * self.l1(feat_output_comp[i], feat_gt[i])
styleLoss = 0.0
for i in range(3):
styleLoss += 120 * self.l1(gram_matrix(feat_output[i]),
gram_matrix(feat_gt[i]))
styleLoss += 120 * self.l1(gram_matrix(feat_output_comp[i]),
gram_matrix(feat_gt[i]))
""" if self.numOfGPUs > 1:
holeLoss = holeLoss.sum() / self.numOfGPUs
validAreaLoss = validAreaLoss.sum() / self.numOfGPUs
prcLoss = prcLoss.sum() / self.numOfGPUs
styleLoss = styleLoss.sum() / self.numOfGPUs """
self.writer.add_scalar('LossG/Hole loss', holeLoss.item(), count)
self.writer.add_scalar('LossG/Valid loss', validAreaLoss.item(), count)
self.writer.add_scalar('LossG/msr loss', msrloss.item(), count)
self.writer.add_scalar('LossPrc/Perceptual loss', prcLoss.item(), count)
self.writer.add_scalar('LossStyle/style loss', styleLoss.item(), count)
GLoss = msrloss+ holeLoss + validAreaLoss+ prcLoss + styleLoss + 0.1 * D_fake + 1*mask_loss
self.writer.add_scalar('Generator/Joint loss', GLoss.item(), count)
return GLoss
train.py:
import os
import math
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from PIL import Image
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import utils
from data.dataloader import ErasingData
from loss.Loss import LossWithGAN_STE
from models.Model import VGG16FeatureExtractor
from models.sa_gan import STRnet2
torch.set_num_threads(5)
#os.environ[“CUDA_VISIBLE_DEVICES”] = “3” ### set the gpu as No…
parser = argparse.ArgumentParser()
parser.add_argument(‘–numOfWorkers’, type=int, default=0,
help=‘workers for dataloader’)
parser.add_argument(‘–modelsSavePath’, type=str, default=‘./checkpoint’,
help=‘path for saving models’)
parser.add_argument(‘–logPath’, type=str,
default=‘’)
parser.add_argument(‘–batchSize’, type=int, default=3)
parser.add_argument(‘–loadSize’, type=int, default=512,
help=‘image loading size’)
parser.add_argument(‘–dataRoot’, type=str, default=“”)
parser.add_argument(‘–pretrained’,type=str, default=‘’, help=‘pretrained models for finetuning’)
parser.add_argument(‘–num_epochs’, type=int, default=500, help=‘epochs’)
args = parser.parse_args()
torch.autograd.set_detect_anomaly(True)
def visual(image):
im = image.transpose(1,2).transpose(2,3).detach().cpu().numpy()
Image.fromarray(im[0].astype(np.uint8)).show()
cuda = torch.cuda.is_available()
if cuda:
print(‘Cuda is available!’)
cudnn.enable = True
cudnn.benchmark = True
batchSize = args.batchSize
loadSize = (args.loadSize, args.loadSize)
if not os.path.exists(args.modelsSavePath):
os.makedirs(args.modelsSavePath)
dataRoot = args.dataRoot
import pdb;pdb.set_trace()
Erase_data = ErasingData(dataRoot, loadSize, training=True)
Erase_data = DataLoader(Erase_data, batch_size=batchSize,
shuffle=True, num_workers=args.numOfWorkers, drop_last=False, pin_memory=True)
device = torch.device(‘cuda:0’)
netG = STRnet2(3)
if args.pretrained != ‘’:
print('loaded ')
netG.load_state_dict(torch.load(args.pretrained))
numOfGPUs = torch.cuda.device_count()
if cuda:
netG = netG.to(device)
if numOfGPUs > 3:
netG = nn.DataParallel(netG, device_ids=range(numOfGPUs))
count = 1
G_optimizer = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.9))
criterion = LossWithGAN_STE(args.logPath, VGG16FeatureExtractor(), lr=0.00001, betasInit=(0.0, 0.9), Lamda=10.0)
if cuda:
criterion = criterion.to(device)
if numOfGPUs > 3:
criterion = nn.DataParallel(criterion, device_ids=range(numOfGPUs))
print(‘OK!’)
num_epochs = args.num_epochs
for i in range(1, num_epochs + 1):
netG.train()
for k,(imgs, gt, masks, path) in enumerate(Erase_data):
if cuda:
imgs = imgs.to(device)
gt = gt.to(device)
masks = masks.to(device)
netG.zero_grad()
x_o1,x_o2,x_o3,fake_images,mm = netG(imgs)
G_loss = criterion(imgs, masks, x_o1, x_o2, x_o3, fake_images, mm, gt, count, i)
G_loss = G_loss.sum()
#G_loss = torch.sum(G_loss)
G_optimizer.zero_grad()
G_loss.backward()
G_optimizer.step()
print('[{}/{}] Generator Loss of epoch{} is {}'.format(k,len(Erase_data),i, G_loss.item()))
count += 1
if ( i % 10 == 0):
if numOfGPUs > 1 :
torch.save(netG.module.state_dict(), args.modelsSavePath +
'/STE_{}.pth'.format(i))
else:
torch.save(netG.state_dict(), args.modelsSavePath +
'/STE_{}.pth'.format(i))
when run G_loss.backward() raise error as follows
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!