import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, models
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def im_convert(tensor):
    """Convert a normalized image tensor into a displayable NumPy image.

    Args:
        tensor: Image tensor on any device. After singleton dimensions are
            squeezed it must be 3-D and channels-first, because the result
            is transposed with ``(1, 2, 0)``.

    Returns:
        A NumPy array with the channel axis last and values clipped to
        the range [0, 1].
    """
    # Work on a detached CPU copy so the caller's tensor is left untouched.
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    # Channels-first (C, H, W) -> channels-last (H, W, C).
    image = image.transpose(1, 2, 0)
    # Undo the per-channel normalization (same constants as ``load_image``).
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    return image.clip(0, 1)
def load_image(path, max_size=400, shape=None):
    """Load an image and turn it into a normalized, batched tensor.

    Args:
        path: File name or file-like object accepted by ``Image.open``.
        max_size: Upper bound for the resize target; used only when
            ``shape`` is None.
        shape: Optional explicit resize target. When given it overrides
            ``max_size`` and is passed to ``transforms.Resize`` unchanged.

    Returns:
        A tensor with a leading batch dimension of 1 and at most three
        channels, normalized with the per-channel mean/std below.
    """
    image = Image.open(path).convert('RGB')

    # Never upscale: cap the resize target at the image's largest side.
    size = max_size if max(image.size) > max_size else max(image.size)
    if shape is not None:
        size = shape

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])

    # Discard the transparent alpha channel (:3) and add the batch dimension.
    image = in_transform(image)[:3, :, :].unsqueeze(0)
    return image
def image_to_tensor(x):
    """Transform an np.array into a channels-first float32 torch.Tensor.

    Supported layouts:
        (W, H)       -> (1, 1, W, H)
        (W, H, C)    -> (1, C, W, H)
        (B, W, H, C) -> (B, C, W, H)

    Args:
        x: NumPy array with 2, 3 or 4 dimensions.

    Returns:
        A new 4-D float32 tensor (the data is copied).

    Raises:
        RuntimeError: If ``x.ndim`` is not 2, 3 or 4.
    """
    if x.ndim == 2:
        return torch.tensor(x, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    if x.ndim == 3:
        # Move the trailing channel axis to the front, then add the batch axis.
        return torch.tensor(x.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0)
    if x.ndim == 4:
        return torch.tensor(x.transpose(0, 3, 1, 2), dtype=torch.float32)
    raise RuntimeError("np.array's ndim is out of range 2, 3 or 4.")
def extract_masks(segment):
    """Extract the segmentation masks from the segmented image.

    Allowed colors are, in the order the masks are returned:
        blue, green, black, white, red,
        yellow, grey, light_blue, purple.

    Args:
        segment: Array whose last axis holds the colour channels; index 0,
            1 and 2 of that axis are read as R, G and B. The thresholds
            below expect channel values scaled to [0, 1].

    Returns:
        A list of nine boolean masks, one per allowed colour. Each mask has
        the shape of ``segment`` without its last axis.
    """
    channels = [segment[..., idx] for idx in range(3)]
    levels = {
        "lo": [channel < 0.1 for channel in channels],
        "hi": [channel > 0.9 for channel in channels],
        "mid": [(channel > 0.4) & (channel < 0.6) for channel in channels],
    }

    # Required (R, G, B) level for every allowed colour.
    palette = (
        ("lo", "lo", "hi"),     # blue
        ("lo", "hi", "lo"),     # green
        ("lo", "lo", "lo"),     # black
        ("hi", "hi", "hi"),     # white
        ("hi", "lo", "lo"),     # red
        ("hi", "hi", "lo"),     # yellow
        ("mid", "mid", "mid"),  # grey
        ("lo", "hi", "hi"),     # light_blue
        ("hi", "lo", "hi"),     # purple
    )

    # Collect one mask per colour; a pixel matches when all three channels
    # are at the required level.
    return [
        levels[r][0] & levels[g][1] & levels[b][2]
        for r, g, b in palette
    ]
def get_all_masks(path):
    """Return the segmentation masks from the segmented image.

    Args:
        path: File name or file-like object accepted by ``Image.open``.

    Returns:
        The list of masks produced by ``extract_masks`` for the image,
        after its pixel values have been scaled to [0, 1].
    """
    # Force three colour channels so palette or greyscale files still give
    # an array with a trailing RGB axis.
    image = Image.open(path).convert("RGB")
    # ``np.float`` was removed from NumPy; the builtin ``float`` is the same
    # 64-bit type.
    np_image = np.array(image, dtype=float) / 255
    return extract_masks(np_image)
def is_nonzero(mask, thrs=0.01):
    """Check whether a segmentation mask is dense enough.

    Args:
        mask: NumPy array (typically boolean).
        thrs: Minimum mean value of the mask, exclusive.

    Returns:
        True when the sum of the mask divided by its element count is
        strictly greater than ``thrs``; False otherwise, including for an
        empty mask.
    """
    if mask.size == 0:
        # Avoid the 0 / 0 division; an empty mask covers nothing.
        return False
    return bool(np.sum(mask) / mask.size > thrs)
def get_masks(path_style, path_content):
    """Return the meaningful segmentation masks.

    Avoids the "orphan semantic labels" problem: a colour is kept only when
    its mask is dense (see ``is_nonzero``) in BOTH the style and the content
    segmentation.

    Args:
        path_style: Source of the style segmentation image.
        path_content: Source of the content segmentation image.

    Returns:
        A tuple ``(masks_style, masks_content)`` of equally long lists whose
        entries at the same index belong to the same colour.
    """
    masks_style = get_all_masks(path_style)
    masks_content = get_all_masks(path_content)

    kept_pairs = [
        (mask_s, mask_c)
        for mask_s, mask_c in zip(masks_style, masks_content)
        if is_nonzero(mask_c) and is_nonzero(mask_s)
    ]

    masks_style = [mask_s for mask_s, _ in kept_pairs]
    masks_content = [mask_c for _, mask_c in kept_pairs]
    return masks_style, masks_content
def resize_masks(masks_style, masks_content, size):
    """Resize all masks to the given size.

    Args:
        masks_style: List of style masks.
        masks_content: List of content masks.
        size: Output shape handed to ``skimage.transform.resize``.

    Returns:
        A tuple ``(masks_style, masks_content)`` with every mask resized.
    """
    def _resize_mask(mask):
        return resize(mask, size, mode="reflect")

    masks_style = [_resize_mask(mask) for mask in masks_style]
    masks_content = [_resize_mask(mask) for mask in masks_content]
    return masks_style, masks_content
def masks_to_tensor(masks_style, masks_content):
    """Transform the masks from np.array to torch.Tensor.

    Args:
        masks_style: List of style masks as NumPy arrays.
        masks_content: List of content masks as NumPy arrays.

    Returns:
        A tuple ``(masks_style, masks_content)`` in which every mask has
        been converted with ``image_to_tensor``.
    """
    masks_style = [image_to_tensor(mask) for mask in masks_style]
    masks_content = [image_to_tensor(mask) for mask in masks_content]
    return masks_style, masks_content
def masks_loader(path_style, path_content, size):
    """Load the style and content masks as tensors of the given size.

    The masks are extracted and filtered (``get_masks``), resized
    (``resize_masks``) and finally converted to tensors
    (``masks_to_tensor``).

    Args:
        path_style: Source of the style segmentation image.
        path_content: Source of the content segmentation image.
        size: Target size for every mask.

    Returns:
        A tuple ``(style_masks, content_masks)`` of lists of tensors.
    """
    style_masks, content_masks = get_masks(path_style, path_content)
    style_masks, content_masks = resize_masks(style_masks, content_masks, size)
    style_masks, content_masks = masks_to_tensor(style_masks, content_masks)
    return style_masks, content_masks
def get_features(image, model, layers=None):
    """Run an image forward through a model and get the features for
    a set of layers.

    Args:
        image: Input tensor fed to the first child module of ``model``.
        model: Module whose direct children are applied one after another.
        layers: Mapping from child-module name to the key under which that
            module's output is stored. Defaults to the VGG19 layers matching
            Gatys et al. (2016).

    Returns:
        A dict mapping the chosen keys to the corresponding activations.
    """
    if layers is None:
        # Indices into ``vgg19().features``; 'conv4_2' is the content
        # representation, the others are style representations.
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',
            '28': 'conv5_1',
        }

    features = {}
    x = image
    # Walk the direct children in registration order, feeding each one the
    # previous output.
    for name, layer in model.named_children():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x

    return features
def gram_matrix(tensor):
    """Calculate the Gram matrix of a given tensor.

    Args:
        tensor: 4-D tensor ``(batch, depth, height, width)``. The reshape
            below only succeeds when the batch size is 1.

    Returns:
        A ``(depth, depth)`` tensor holding the inner products between the
        flattened channels.
    """
    # Get the batch_size, depth, height, and width of the tensor.
    _, d, h, w = tensor.size()

    # Flatten every channel into a row; ``reshape`` (unlike ``view``) also
    # accepts non-contiguous input.
    features = tensor.reshape(d, h * w)

    # Gram matrix: channel-by-channel inner products.
    return torch.mm(features, features.t())
def PhotorealismRegularization(image, target):
    """Compute the photorealism regularization loss and its gradient.

    NOTE(review): ``compute_laplacian`` is neither defined nor imported in
    the visible part of this file. It is presumably a matting-Laplacian
    routine returning an (N, N) matrix for the N pixels of ``image`` —
    confirm where it comes from.

    Args:
        image: Reference image handed to ``compute_laplacian``.
        target: NumPy array that is flattened to ``(-1, 3)``, i.e. one row
            per pixel and one column per colour channel.

    Returns:
        A tuple ``(loss, grad)``: ``loss`` is a 0-dim tensor holding the sum
        over channels of ``V_c^T L V_c``, and ``grad`` is ``2 * L * V``
        reshaped to ``target.shape``.
    """
    L = compute_laplacian(image)
    flat_target = target.reshape(-1, 3)
    grad = L.dot(flat_target)
    # Per-channel quadratic form. The earlier ``target.reshape(3, -1)`` did
    # not transpose the channels-last data (it scrambled pixels) and summed
    # cross-channel terms, which did not match the gradient returned below.
    loss = np.sum(flat_target * grad)
    return torch.tensor(loss), 2. * grad.reshape(*target.shape)
import io

# NOTE(review): ``uploaded`` is not defined in the visible part of this file.
# It is assumed to map file names to raw image bytes (e.g. the result of a
# notebook file upload) — confirm.
style_path = io.BytesIO(uploaded['tar4.png'])
content_path = io.BytesIO(uploaded['in4.png'])
style_mask_path = io.BytesIO(uploaded['seg_tar4.png'])
content_mask_path = io.BytesIO(uploaded['seg_in4.png'])

# Only the feature-extraction part of VGG19 is used.
vgg = models.vgg19(pretrained=True).features

# Freeze the network: only the target image is optimized.
for param in vgg.parameters():
    param.requires_grad_(False)

# The images below are moved to ``device``, so the model must live there too.
vgg.to(device)

# load content and style
content = load_image(content_path).to(device)
# resize style to match content
style = load_image(style_path, shape=content.shape[-2:]).to(device)

# load the masks, resized to the spatial size of the content image
style_masks, content_masks = masks_loader(
    style_mask_path, content_mask_path, tuple(content.shape[-2:]))

# get content and style features
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)

# get the gram matrices
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

# get the target image: start from a copy of the content image
target = content.clone().requires_grad_(True).to(device)

# per-layer weights of the style loss
style_weights = {'conv1_1': 0.2,
                 'conv2_1': 0.2,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}

content_weight = 1  # Alpha
style_weight = 1e2  # Gamma
photorealismRegularization_weight = 1e4  # lambda

# constants
content_layer = 'conv4_2'

# for displaying the target image, intermittently
show_every = 400

# iteration hyperparameters
optimizer = optim.Adam([target], lr=0.003)
steps = 2000  # decide how many iterations to update your image (5000)
# Style layers of the VGG19 feature extractor: child-module name -> label.
# 'conv4_2' is deliberately absent: it is the content layer and has no entry
# in ``style_weights``.
layers = {
    '0': 'conv1_1',
    '5': 'conv2_1',
    '10': 'conv3_1',
    '19': 'conv4_1',
    '28': 'conv5_1',
}


def _masked_grams(image, masks, model, layers):
    """Return ``{label: [Gram matrix per mask]}`` for the layers in ``layers``.

    The masks are pooled alongside the features so that they keep the same
    spatial size. Pooling is done on local copies: the caller's mask list is
    never modified, so it can be reused on every optimization step.
    """
    grams = {}
    x = image
    masks = [mask.to(image.device) for mask in masks]
    for name, layer in model.named_children():
        x = layer(x)
        if isinstance(layer, torch.nn.MaxPool2d):
            masks = [layer(mask) for mask in masks]
        if name in layers:
            grams[layers[name]] = [gram_matrix(x * mask) for mask in masks]
    return grams


# The style side of the augmented style loss never changes, so it is computed
# once, from the STYLE image and the style masks.
with torch.no_grad():
    style_masked_grams = _masked_grams(style, style_masks, vgg, layers)

for ii in range(1, steps + 1):
    # get the features from your target image
    target_features = get_features(target, vgg)

    # content loss
    content_loss = torch.mean((target_features[content_layer] - content_features[content_layer]) ** 2)

    # photorealistic regularization loss
    # NOTE(review): ``reg_loss`` is built from detached NumPy data, so it adds
    # to the reported total but contributes no gradient. The analytic gradient
    # returned as the second value was computed but never applied (the
    # ``target.grad += ...`` line was commented out); it is ignored here too.
    reg_loss, _ = PhotorealismRegularization(
        content.detach().to("cpu").numpy().transpose(0, 2, 3, 1).squeeze().clip(0, 1),
        target.detach().to("cpu").numpy().transpose(0, 2, 3, 1).squeeze().clip(0, 1))

    # augmented style loss: masked Gram matrices of the target (content
    # masks) against those of the style image (style masks)
    target_masked_grams = _masked_grams(target, content_masks, vgg, layers)
    style_loss = 0
    for label, target_gram in target_masked_grams.items():
        layer_loss = style_weights[label] * sum(
            F.mse_loss(t_gram, s_gram)
            for t_gram, s_gram in zip(target_gram, style_masked_grams[label]))
        _, d, h, w = target_features[label].shape
        style_loss = style_loss + layer_loss / (d * h * w)

    # total loss
    total_loss = content_weight * content_loss + style_weight * style_loss + photorealismRegularization_weight * reg_loss

    # update target image
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # display intermediate images and print the loss
    if ii % show_every == 0:
        print('Total loss: ', total_loss.item())
        plt.imshow(im_convert(target))
        plt.show()