Image texture synthesis paper implementation – not achieving the same results

Hi, I am trying to reimplement in PyTorch a paper whose original implementation is in Caffe:

https://github.com/leongatys/DeepTextures/tree/master/DeepImageSynthesis

For some reason, I am not able to reproduce exactly the same results in PyTorch. I have checked my code several times and found no errors.

If anyone has implemented this paper in PyTorch, could they share their code?

Otherwise, here is my code, in case anyone has feedback for me.

Thank you very much for your time! :slight_smile:

from __future__ import print_function

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim

import numpy as np

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models

import copy

use_cuda = torch.cuda.is_available()
# Tensor type used for every image in the script, and the edge length images
# are resized to (smaller on CPU, presumably to keep CPU runs fast).
if use_cuda:
    dtype = torch.cuda.FloatTensor
    imsize = 512
else:
    dtype = torch.FloatTensor
    imsize = 128


# Channel statistics torchvision's pretrained VGG-19 was trained with. Its
# inputs must be normalised with BOTH the mean and the std; using std=1 feeds
# the network activations roughly 4x smaller than it expects.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize (shorter side -> imsize), convert to a [0, 1] tensor, then normalise
# each channel as the pretrained network expects.
prep = transforms.Compose([
    transforms.Resize(imsize),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Exact inverse of the normalisation in `prep`: (x - (-m/s)) / (1/s) == x*s + m,
# followed by conversion to a PIL image for display.
undo = transforms.Compose([
    transforms.Normalize(mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
                         std=[1.0 / s for s in IMAGENET_STD]),
    transforms.ToPILImage(),  # change to PIL image
])


def image_loader(image_name):
    """Load an image file and return it as a preprocessed batch of one.

    The image is converted to RGB first, so grayscale, palette or RGBA files
    also yield exactly the three channels that `prep` normalises.

    Args:
        image_name: path of the image file to open.

    Returns:
        A ``Variable`` holding ``prep(image)`` with a leading batch dimension
        of size 1, i.e. shape (1, 3, H, W).
    """
    # The context manager closes the file; convert() returns an in-memory
    # copy, which stays valid after the file is closed.
    with Image.open(image_name) as raw:
        rgb = raw.convert('RGB')
    image = Variable(prep(rgb))
    return image.unsqueeze(0)


def imshow(tensor, title=None, preprocess=True):
    """Display an image tensor in a new matplotlib figure.

    Args:
        tensor: image tensor holding a single 3-channel image, e.g. shaped
            (1, 3, H, W) or (3, H, W). H and W need not both equal ``imsize``.
        title: optional figure title.
        preprocess: if True, invert the `prep` normalisation with `undo`
            before display; otherwise convert the raw values directly.
    """
    # Work on a CPU copy so the caller's tensor is left untouched.
    image = tensor.clone().cpu()
    # Keep the tensor's own height and width: resizing with an int only
    # fixes the shorter side, so non-square images are not imsize x imsize.
    image = image.view(image.size()[-3:])
    image = undo(image) if preprocess else transforms.ToPILImage()(image)

    plt.figure()
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.005)  # give the GUI event loop a moment to draw the figure


def get_min_max(image):
    """Return the per-channel minimum and maximum of the first image in a batch.

    Args:
        image: tensor (or Variable) shaped (batch, channels, H, W); only
            ``image[0]`` is inspected.

    Returns:
        ``(mins, maxs)``: two lists of Python floats with one entry per
        channel.
    """
    # Iterating a (C, H, W) tensor yields one (H, W) slice per channel.
    channels = image.data[0].cpu()
    mins = [float(channel.min()) for channel in channels]
    maxs = [float(channel.max()) for channel in channels]
    return mins, maxs


# Stand-alone tensor -> PIL converter (not referenced elsewhere in this script).
unloader = transforms.ToPILImage()  # reconvert into PIL image
plt.ion()  # interactive mode: figures are drawn without blocking the script

# Target texture, cast to `dtype` (CUDA tensors when a GPU is available).
texture = image_loader("images/pebbles.jpg").type(dtype)
# Starting point of the synthesis: Gaussian noise with the texture's shape.
noise = Variable(torch.randn(texture.data.size())).type(dtype)
# Per-channel value range of the (preprocessed) texture.
# NOTE(review): these names shadow the builtins `min`/`max` module-wide.
min,max=get_min_max(texture)

imshow(texture.data, title='Texture')
# The noise is not a `prep` output, so it is shown without `undo`.
imshow(noise.data, title='Noise', preprocess=False)


# Convolutional part of torchvision's ImageNet-pretrained VGG-19, switched to
# inference mode (eval() returns the module itself).
# NOTE(review): Gatys et al. used a rescaled ("normalised") VGG-19 with
# average pooling in Caffe; this model differs in both respects.
cnn = models.vgg19(pretrained=True).features.eval()



class GramMatrix(nn.Module):
    """Batched Gram matrix of feature maps, normalised by the spatial size.

    An input of shape (b, N, h, w) is flattened to F of shape (b, N, M) with
    M = h * w, and G = F F^T / M of shape (b, N, N) is returned.

    Why divide by M only: the texture loss in this file compares Gram
    matrices with ``nn.MSELoss``, which averages over the N*N entries. With
    G = F F^T / M, the per-layer loss (batch of one) is
    sum((G_raw - A_raw)^2) / (N^2 M^2). That is 4x the layer loss
    E_l = sum((G_raw - A_raw)^2) / (4 N^2 M^2) of Gatys et al., and the
    factor is the same for every layer, so the relative weighting of layers
    matches the paper. Dividing additionally by 4*N would shrink each
    layer's loss by a further 16 N^2. That under-weights the wide (deep)
    layers by up to 64x relative to conv1_1.
    """

    def forward(self, input):
        b, N, h, w = input.size()
        M = h * w

        F = input.view(b, N, M)
        G = torch.bmm(F, F.transpose(1, 2))
        return G / M
    
class textureLoss(nn.Module):
    """Pass-through module that records a Gram-matrix texture loss.

    ``forward`` returns a copy of its input unchanged. It stores in
    ``self.loss`` the mean squared error between the weighted Gram matrix of
    the input and the weighted target Gram matrix.

    Both the Gram matrix and the target are multiplied by ``weight`` before
    the squared error, so the stored loss scales with ``weight ** 2``.

    Args:
        target: target Gram matrix for this layer. It is detached, so no
            gradient flows back into whatever produced it.
        weight: scalar multiplier applied to both Gram matrices.
    """

    def __init__(self, target, weight):
        super(textureLoss, self).__init__()
        self.target = target.detach() * weight
        self.weight = weight
        self.gram = GramMatrix()
        self.criterion = nn.MSELoss()

    def forward(self, input):
        # The Gram matrix is computed from `input` while a clone is passed on,
        # presumably so that a following in-place layer cannot alter `input`.
        self.output = input.clone()
        self.G = self.gram(input)
        self.G.mul_(self.weight)
        self.loss = self.criterion(self.G, self.target)
        return self.output

    def backward(self, retain_variables=True):
        """Back-propagate ``self.loss`` and return it.

        Args:
            retain_variables: keep the autograd graph after this backward
                pass. This is needed because several loss modules
                back-propagate through the same forward graph one after
                another. It is forwarded as ``retain_graph``; the old
                ``retain_variables`` keyword of ``Tensor.backward`` is no
                longer accepted by PyTorch.
        """
        self.loss.backward(retain_graph=retain_variables)
        return self.loss


# Put the feature extractor on the GPU when one is available.
cnn = cnn.cuda() if use_cuda else cnn

# Layers whose Gram matrices define the texture. Names are
# <type>_<1-based position in cnn>; for torchvision's VGG-19 these are
# conv1_1 and pool1 - pool4, the layers used by Gatys et al.
texture_layers_default = ['conv_1', 'pool_5', 'pool_10', 'pool_19', 'pool_28']


# Print the position-based name of every layer in `cnn`; the counter starts
# at 1 and advances once per layer. Pool names carry a trailing newline, so a
# blank line is printed after each pooling layer.
i = 1
for layer in list(cnn):
    if isinstance(layer, nn.MaxPool2d):
        name = "pool_" + str(i) + '\n'
    elif isinstance(layer, nn.ReLU):
        name = "relu_" + str(i)
    elif isinstance(layer, nn.Conv2d):
        name = "conv_" + str(i)
    i += 1
    print(name)



def get_texture_model_and_losses(cnn, texture_img,
                                 texture_weight,
                                 texture_layers=texture_layers_default,
                                 avg_pool=False):
    """Copy ``cnn`` into a new network with texture-loss probes inserted.

    Conv, ReLU and max-pool layers of ``cnn`` are copied in order into a new
    ``nn.Sequential``. Each is named ``conv_<i>``, ``relu_<i>`` or
    ``pool_<i>``, where ``i`` counts the copied layers starting at 1; other
    layer types are skipped. Right after each layer whose name is in
    ``texture_layers``, a ``textureLoss`` is inserted. Its target is the Gram
    matrix of ``texture_img``'s activations at that point.

    Args:
        cnn: sequential feature extractor (deep-copied, not modified).
        texture_img: preprocessed target texture.
        texture_weight: weight passed to every ``textureLoss``.
        texture_layers: names of the layers to attach texture losses to.
        avg_pool: if True, each ``nn.MaxPool2d`` is replaced by an
            ``nn.AvgPool2d`` with the same kernel size, stride, padding and
            ceil mode (Gatys et al. report slightly cleaner syntheses with
            average pooling). The default keeps max pooling.

    Returns:
        ``(model, texture_losses)``: the assembled ``nn.Sequential`` and the
        inserted ``textureLoss`` modules in network order.
    """
    cnn = copy.deepcopy(cnn)

    texture_losses = []

    model = nn.Sequential()
    gram = GramMatrix()

    if use_cuda:
        model = model.cuda()
        gram = gram.cuda()

    i = 1
    for layer in list(cnn):
        if isinstance(layer, nn.Conv2d):
            name = "conv_" + str(i)
        elif isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
        elif isinstance(layer, nn.MaxPool2d):
            name = "pool_" + str(i)
            if avg_pool:
                layer = nn.AvgPool2d(kernel_size=layer.kernel_size,
                                     stride=layer.stride,
                                     padding=layer.padding,
                                     ceil_mode=layer.ceil_mode)
        else:
            # Unrecognised layer types are dropped and do not advance `i`.
            continue

        model.add_module(name, layer)
        if name in texture_layers:
            print(name)
            # Target statistics: Gram matrix of the texture's activations
            # after the layer just added.
            target_feature = model(texture_img).clone()
            target_feature_gram = gram(target_feature)
            texture_loss = textureLoss(target_feature_gram, texture_weight)
            model.add_module("texture_loss_" + str(i), texture_loss)
            texture_losses.append(texture_loss)
        i += 1

    return model, texture_losses


def get_input_param_optimizer(input_img):
    """Wrap the image data in a trainable Parameter and attach an L-BFGS optimiser.

    Args:
        input_img: tensor or Variable whose ``.data`` becomes the parameter.

    Returns:
        ``(param, optimizer)`` where ``optimizer`` is ``optim.LBFGS``
        over the single parameter ``param``.
    """
    param = nn.Parameter(input_img.data)
    return param, optim.LBFGS([param])


def run_style_transfer(cnn, texture_img, input_img, num_steps=400, texture_weight=1e9):
    """Synthesise a texture by matching Gram matrices with L-BFGS.

    Args:
        cnn: feature network handed to ``get_texture_model_and_losses``.
        texture_img: preprocessed target texture, indexed as
            ``[0, channel, :, :]``.
        input_img: initial image (e.g. noise) of the same shape.
        num_steps: loss-evaluation count after which no new
            ``optimizer.step`` is started. One step may evaluate the loss
            several times, so the final count can overshoot slightly.
        texture_weight: weight forwarded to every ``textureLoss``.

    Returns:
        The optimised image tensor, with each channel clamped to that
        channel's value range in ``texture_img``.
    """
    print('Building the style transfer model..')
    model, texture_losses = get_texture_model_and_losses(
        cnn, texture_img, texture_weight)
    input_param, optimizer = get_input_param_optimizer(input_img)

    # Per-channel bounds taken from the texture itself, so this function does
    # not depend on the module-level `min`/`max` names (which shadow the
    # builtins).
    texture_channels = texture_img.data[0]
    lower = [float(channel.min()) for channel in texture_channels]
    upper = [float(channel.max()) for channel in texture_channels]

    def clamp_to_texture_range():
        """Clip each channel of the image being optimised, in place."""
        for k, (lo, hi) in enumerate(zip(lower, upper)):
            input_param.data[0, k].clamp_(lo, hi)

    print('Optimizing..')
    run = [0]  # closure-evaluation counter; a list so the closure can mutate it

    # NOTE(review): clamping the parameter inside the L-BFGS closure changes
    # it behind the optimiser's back. The Caffe reference reportedly bounds
    # pixels with L-BFGS-B instead; confirm whether this explains remaining
    # differences.
    def closure():
        clamp_to_texture_range()
        optimizer.zero_grad()
        model(input_param)

        # Each loss module back-propagates on its own; the gradients
        # accumulate in input_param.grad.
        texture_score = 0
        for tl in texture_losses:
            texture_score += tl.backward()

        run[0] += 1
        if run[0] % 50 == 0:
            print("run {}:".format(run[0]))
            print('Style Loss : {:4f} '.format(float(texture_score)))
            print()

        return texture_score

    while run[0] <= num_steps:
        optimizer.step(closure)

    clamp_to_texture_range()
    return input_param.data



# Synthesise a texture starting from the noise image.
output = run_style_transfer(cnn, texture, noise, num_steps=800, texture_weight=1e9)

# imshow() opens its own figure; an extra plt.figure() here would only leave
# an empty window behind.
imshow(output, title='Output Image', preprocess=True)

# Leave interactive mode so plt.show() blocks until the windows are closed.
plt.ioff()
plt.show()

Assignment 3 of Stanford's CS231n course (http://cs231n.github.io/assignments2017/assignment3/) asks you to implement Gatys et al.'s style transfer algorithm (https://arxiv.org/abs/1508.06576). PyTorch is one of the two options (the other is TensorFlow), and the course provides an IPython notebook with starter code and correctness-checking routines. You may find it helpful.

1 Like