Hello. Thank you very much for a great project; PyTorch is really making my deep-learning dreams come true. This is something I've wanted to do for a while. I'm hoping for some direction on my project: once I know the pipeline flows smoothly end to end, I'll spend more time getting acquainted with the details.
I've trained a Pix2PixHD model for image-to-image translation (no labels) on 512x512 paired images, and the model is amazing. I'm now trying to produce an input to the Pix2PixHD model that results in an output with a particular visual style (e.g. that of Starry Night).
The dataset in the enumerate loop below contains just a single 512x512 noise PNG, which is passed to a run_style_transfer function adapted from the PyTorch neural style transfer tutorial.
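In case the input format matters, the noise PNG was made with something along these lines (a sketch; the path and distribution here are placeholders rather than my exact values):

```python
import numpy as np
from PIL import Image

# Sketch: write a 512x512 RGB uniform-noise image for the test label folder.
# The path and distribution are placeholders.
noise = np.random.randint(0, 256, size=(512, 512, 3), dtype=np.uint8)
Image.fromarray(noise, mode="RGB").save("datasets/test_label/noise.png")
```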
The vgg_model lets me optimise the style of the Pix2PixHD output, but I don't appear to be able to propagate the changes back to the input image. I'd appreciate a pointer in the right direction, and I'm really looking forward to learning a lot more with PyTorch!
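Before the full listing, here's the sanity check I've been running to see whether gradients survive the Pix2PixHD forward pass at all (a sketch; `data` is one item from the loader used below, and the `.inference()` call matches how I use it in the script):

```python
# Sketch: does the autograd graph survive pix_model.inference()?
input_img = data['label'].to(device).requires_grad_()
generated = pix_model.inference(input_img, data['inst'], data['label'])
# If this prints False, the graph is cut inside inference() and no
# gradient can ever reach input_img.
print(generated[0].requires_grad)
if generated[0].requires_grad:
    generated[0].float().mean().backward()
    # None or all-zero here would also mean nothing propagates back
    print(input_img.grad)
```

The full script: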
```python
import os
from collections import OrderedDict
from torch.autograd import Variable
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
import torch
from PIL import Image
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torchvision.utils import save_image
import copy
opt = TestOptions().parse(save=False)
opt.nThreads = 1 # test code only supports nThreads = 1
opt.batchSize = 1 # test code only supports batchSize = 1
opt.serial_batches = True # no shuffle
opt.no_flip = True # no flip
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
pix_model = create_model(opt)
for param in pix_model.parameters():
    param.requires_grad = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
imsize = 512
loader = transforms.Compose([
    transforms.Resize(imsize),  # scale imported image
    transforms.ToTensor()])     # transform it into a torch tensor
def image_loader(image_name):
    image = Image.open(image_name)
    # fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)
style_img = image_loader("style_images/starry_night.png")
#style_img = image_loader("style_images/Picasso2.jpg")
# style_img = image_loader("style_images/happy-nature-maciek-froncisz.png")
def gram_matrix(input):
    print(input.size())
    a, b, c, d = input.size()
    # flatten each feature map and take the inner product, normalising
    # by the total number of elements
    features = input.view(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)
class StyleLoss(nn.Module):
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input
cnn = models.vgg19(pretrained=True).features.to(device).eval()
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
# create a module to normalize the input image so we can easily put it in
# an nn.Sequential
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # work directly with an image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def forward(self, img):
        # normalize img
        return (img - self.mean) / self.std
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
# style_layers_default = ['conv_1', 'conv_5']
# style_layers_default = ['conv_5']
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, style_layers=style_layers_default):
    cnn = copy.deepcopy(cnn)
    normalization = Normalization(normalization_mean, normalization_std).to(device)
    # just in order to have iterable access to the list of style losses
    style_losses = []
    # assuming that cnn is an nn.Sequential, so we make a new nn.Sequential
    # to put in modules that are supposed to be activated sequentially
    model = nn.Sequential(normalization)
    i = 0  # increment every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # The in-place version doesn't play very nicely with the
            # StyleLoss we insert below, so we replace it with an
            # out-of-place one here.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
        model.add_module(name, layer)
        if name in style_layers:
            # add style loss:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)
    # now trim off the layers after the last style loss
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], StyleLoss):
            break
    model = model[:(i + 1)]
    return model, style_losses
def get_input_optimizer(input_img):
    # this line shows that the input is a parameter that requires a gradient
    optimizer = optim.LBFGS([input_img.requires_grad_()])
    return optimizer
def run_style_transfer(cnn, normalization_mean, normalization_std, style_img,
                       input_img, pix_model, data_inst, data_label,
                       num_steps=50000, style_weight=1000000, content_weight=1):
    """Run the style transfer."""
    print('Building the style transfer model..')
    input_img = input_img.to(device)
    pix_model = copy.deepcopy(pix_model).to(device)
    pix_model.eval()
    vgg_model, style_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std, style_img)
    optimizer = get_input_optimizer(input_img)

    print('Optimizing..')
    run = [0]
    while run[0] <= num_steps:

        def closure():
            # correct the values of the updated input image
            input_img.data.clamp_(0, 1)
            optimizer.zero_grad()
            generated = pix_model.inference(input_img, data_inst, data_label)
            # save_image(generated[0].unsqueeze(0), 'image_test.png')
            vgg_model(generated[0].unsqueeze(0))
            vgg_model(input_img)
            style_score = 0
            for sl in style_losses:
                style_score += sl.loss
            style_score *= style_weight
            loss = style_score
            loss.backward()
            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:4f}'.format(style_score.item()))
                print()
            return style_score

        optimizer.step(closure)
    # a last correction...
    input_img.data.clamp_(0, 1)
    return input_img
for i, data in enumerate(dataset):
    # save_image(data['label'], 'image_test.png')
    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                                style_img=style_img, input_img=data['label'],
                                data_inst=data['inst'], data_label=data['label'],
                                num_steps=2000, pix_model=pix_model)
    save_image(output, 'image_test.png')
```
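For what it's worth, my current suspicion from reading my copy of pix2pixHD_model.py is that inference() detaches its inputs when encoding them, and on recent PyTorch versions also wraps the generator call in torch.no_grad(), either of which would cut the graph before it reaches input_img. If that's the culprit, would calling the generator directly be the right fix? Something like this sketch (assuming opt.label_nc == 0 so the label image feeds the generator as-is, and that netG is the generator attribute; this skips the instance-map edge features, which I'm not using):

```python
# Sketch: bypass inference()/encode_input() so the forward pass stays
# differentiable. Assumes label_nc == 0 (raw image as generator input)
# and no instance-map features.
generated = pix_model.netG.forward(input_img)
vgg_model(generated)
```

Is that the right direction, or is there a cleaner way to get a differentiable forward pass out of Pix2PixHD?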