Train input of Generator based on Style of output

Hello. Thank you very much for a great project; PyTorch is really making my deep-learning dreams come true. This is something I’ve wanted to do for a while.

I was hoping for some direction on my project. Once I know I can make the pipeline flow smoothly, I’ll spend more time getting acquainted with the details :smile:.

I’ve trained a Pix2PixHD model for image translation (no labels) on 512x512 paired images, and the results are amazing.

I’m now trying to optimise an input to the Pix2PixHD model so that its output has a particular visual style (e.g. that of Starry Night).

The dataset in the enumerate loop just contains a single 512x512 noise PNG, which is passed to a run_style_transfer function adapted from the official neural style transfer tutorial.
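
For reference, a noise input like this can be generated with a couple of lines (a sketch; the save path here is an assumption, so adjust it to wherever your dataroot expects the test images):

import numpy as np
from PIL import Image

# 512x512 RGB uniform noise; the path below is illustrative only
noise = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
Image.fromarray(noise).save('datasets/mydata/test_A/noise.png')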

The vgg_model lets me compute a style loss on the output of the Pix2PixHD inference, but the gradients don’t appear to propagate back to the input image. I’d appreciate a pointer in the right direction, and I’m really looking forward to learning a lot more with PyTorch!
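
For context, here is the gradient path I expect, in miniature, with a toy convolution standing in for Pix2PixHD (the names are illustrative only, not from the repo). In this sketch the gradient from a loss on the frozen network’s output does reach the input:

import torch
import torch.nn as nn

toy_gen = nn.Conv2d(3, 3, 3, padding=1).eval()  # stand-in for the frozen generator
for p in toy_gen.parameters():
    p.requires_grad = False

x = torch.rand(1, 3, 64, 64, requires_grad=True)  # the trainable "input image"
out = toy_gen(x)              # forward through the frozen generator
loss = out.pow(2).mean()      # stand-in for the VGG style loss
loss.backward()
print(x.grad.abs().sum())     # non-zero: gradients reach the input here

My full script is below.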

from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import torch

from PIL import Image
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.models as models
from torchvision.utils import save_image

import copy

opt = TestOptions().parse(save=False)
opt.nThreads = 1   # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
pix_model = create_model(opt)

# freeze the generator: only the input image should receive gradients
for param in pix_model.parameters():
    param.requires_grad = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

imsize = 512

loader = transforms.Compose([
    transforms.Resize(imsize),  # scale imported image
    transforms.ToTensor()])  # transform it into a torch tensor

def image_loader(image_name):
    image = Image.open(image_name)
    # fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)

style_img = image_loader("style_images/starry_night.png")

#style_img = image_loader("style_images/Picasso2.jpg")

# style_img = image_loader("style_images/happy-nature-maciek-froncisz.png")

def gram_matrix(input):
    a, b, c, d = input.size()  # batch, channels, height, width
    features = input.view(a * b, c * d)
    # G[i, j] is the correlation between (flattened) feature maps i and j
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)  # normalise by the total number of elements

class StyleLoss(nn.Module):
    """Transparent layer: records the style loss against a fixed target and
    returns its input unchanged, so it can be spliced into the VGG stack."""

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input

# frozen VGG19 feature extractor
cnn = models.vgg19(pretrained=True).features.to(device).eval()
# ImageNet channel statistics that VGG19 was trained with
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)

# create a module to normalize input image so we can easily put it in a
# nn.Sequential
class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can
        # directly work with image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
        self.mean = mean.clone().detach().view(-1, 1, 1)
        self.std = std.clone().detach().view(-1, 1, 1)

    def forward(self, img):
        # normalise img channel-wise
        return (img - self.mean) / self.std

style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
# style_layers_default = ['conv_1', 'conv_5']
# style_layers_default = ['conv_5']

def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, style_layers=style_layers_default):
    cnn = copy.deepcopy(cnn)
    normalization = Normalization(normalization_mean, normalization_std).to(device)

    # keep an iterable list of the style losses so we can read them out later
    style_losses = []

    # assuming that cnn is a nn.Sequential, so we make a new nn.Sequential
    # to put in modules that are supposed to be activated sequentially
    model = nn.Sequential(normalization)

    i = 0  # increment every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # The in-place version doesn't play very nicely with the ContentLoss
            # and StyleLoss we insert below. So we replace with out-of-place
            # ones here.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in style_layers:
            # add style loss:
            target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # now we trim off the layers after the last style loss
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], StyleLoss):
            break

    model = model[:(i + 1)]

    return model, style_losses

def get_input_optimizer(input_img):
    # the input image itself is the parameter being optimised; LBFGS will
    # call the closure passed to step() several times per iteration
    optimizer = optim.LBFGS([input_img.requires_grad_()])
    return optimizer

def run_style_transfer(cnn, normalization_mean, normalization_std,
                       style_img, input_img, pix_model, data_inst, data_label,
                       num_steps=50000, style_weight=1000000):
    """Run the style transfer."""
    print('Building the style transfer model..')
    input_img = input_img.to(device)
    pix_model = copy.deepcopy(pix_model).to(device)
    pix_model.eval()
    vgg_model, style_losses = get_style_model_and_losses(cnn,
        normalization_mean, normalization_std, style_img)
    optimizer = get_input_optimizer(input_img)

    print('Optimizing..')
    run = [0]
    while run[0] <= num_steps:

        def closure():
            # correct the values of updated input image
            input_img.data.clamp_(0, 1)
            optimizer.zero_grad()

            generated = pix_model.inference(input_img, data_inst, data_label)
            # save_image(generated[0].unsqueeze(0), 'image_test.png')
            # score the style of the generated output only; a second forward
            # pass through vgg_model would overwrite self.loss in every StyleLoss
            vgg_model(generated[0].unsqueeze(0))

            style_score = 0

            for sl in style_losses:
                style_score += sl.loss

            style_score *= style_weight

            loss = style_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:4f}'.format(
                    style_score.item()))
                print()

            return style_score
        optimizer.step(closure)
    # a last correction...
    input_img.data.clamp_(0, 1)
    return input_img



for i, data in enumerate(dataset):

    # save_image(data['label'], 'image_test.png')

    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                                style_img=style_img, input_img=data['label'],
                                data_inst=data['inst'], data_label=data['label'],
                                pix_model=pix_model, num_steps=2000)

    save_image(output, 'image_test.png')
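
To debug, a check like this inside closure, right after loss.backward(), should confirm whether any gradient actually lands on the input (a sketch):

# inside closure, after loss.backward():
if input_img.grad is None or input_img.grad.abs().sum().item() == 0:
    print('no gradient reached input_img -- the graph is detached upstream')

If the graph is being cut somewhere inside pix_model.inference (e.g. by a torch.no_grad() block or a .detach()), I assume this is where it would show up.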