import io

import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def im_convert(tensor):
    """ Convert a normalized image tensor back to a displayable numpy array. """
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
    # undo the ImageNet normalization applied in load_image
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)
    return image
def load_image(path, max_size=400, shape=None):
    ''' Load and transform an image, making sure it is at most 400
    pixels in the x-y dimensions. '''
    image = Image.open(path).convert('RGB')
    size = max_size if (max(image.size) > max_size) else max(image.size)
    if shape is not None:
        size = shape
    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))
    ])
    # discard the transparent alpha channel ([:3]) and add the batch dimension
    image = in_transform(image)[:3, :, :].unsqueeze(0)
    return image
def image_to_tensor(x):
    """
    Transforms a np.array into a torch.Tensor:
    (H, W) -> (1, 1, H, W)
    (H, W, C) -> (1, C, H, W)
    (B, H, W, C) -> (B, C, H, W)
    """
    if x.ndim == 2:
        return torch.Tensor(x).unsqueeze(0).unsqueeze(0)
    if x.ndim == 3:
        return torch.Tensor(x.transpose(2, 0, 1)).unsqueeze(0)
    if x.ndim == 4:
        return torch.Tensor(x.transpose(0, 3, 1, 2))
    raise RuntimeError("Expected np.array with ndim 2, 3 or 4.")
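# Illustrative shape checks for image_to_tensor (the shapes are arbitrary):
# image_to_tensor(np.zeros((240, 320))).shape      # -> torch.Size([1, 1, 240, 320])
# image_to_tensor(np.zeros((240, 320, 3))).shape   # -> torch.Size([1, 3, 240, 320])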
def extract_masks(segment):
    """
    Extracts the segmentation masks from the segmented image.
    Allowed colors are:
    blue, green, black, white, red,
    yellow, grey, light_blue, purple.
    """
    extracted_colors = []
    # BLUE
    mask_r = segment[..., 0] < 0.1
    mask_g = segment[..., 1] < 0.1
    mask_b = segment[..., 2] > 0.9
    mask = mask_r & mask_g & mask_b
    extracted_colors.append(mask)
    # GREEN
    mask_r = segment[..., 0] < 0.1
    mask_g = segment[..., 1] > 0.9
    mask_b = segment[..., 2] < 0.1
    mask = mask_r & mask_g & mask_b
    extracted_colors.append(mask)
    # BLACK
    mask_r = segment[..., 0] < 0.1
    mask_g = segment[..., 1] < 0.1
    mask_b = segment[..., 2] < 0.1
    mask = mask_r & mask_g & mask_b
    extracted_colors.append(mask)
    # WHITE
    mask_r = segment[..., 0] > 0.9
    mask_g = segment[..., 1] > 0.9
    mask_b = segment[..., 2] > 0.9
    mask = mask_r & mask_g & mask_b
    extracted_colors.append(mask)
    # RED
    mask_r = segment[..., 0] > 0.9
    mask_g = segment[..., 1] < 0.1
    mask_b = segment[..., 2] < 0.1
    mask = mask_r & mask_g & mask_b
    extracted_colors.append(mask)
    # YELLOW
    mask_r = segment[..., 0] > 0.9
    mask_g = segment[..., 1] > 0.9
    mask_b = segment[..., 2] < 0.1
    mask = mask_r & mask_g & mask_b
    extracted_colors.append(mask)
    # GREY
    mask_r = (segment[..., 0] > 0.4) & (segment[..., 0] < 0.6)
    mask_g = (segment[..., 1] > 0.4) & (segment[..., 1] < 0.6)
    mask_b = (segment[..., 2] > 0.4) & (segment[..., 2] < 0.6)
    mask = mask_r & mask_g & mask_b
    extracted_colors.append(mask)
    # LIGHT_BLUE
    mask_r = segment[..., 0] < 0.1
    mask_g = segment[..., 1] > 0.9
    mask_b = segment[..., 2] > 0.9
    mask = mask_r & mask_g & mask_b
    extracted_colors.append(mask)
    # PURPLE
    mask_r = segment[..., 0] > 0.9
    mask_g = segment[..., 1] < 0.1
    mask_b = segment[..., 2] > 0.9
    mask = mask_r & mask_g & mask_b
    extracted_colors.append(mask)
    return extracted_colors
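# Illustrative use of extract_masks (file name and variables are placeholders):
# np_segmentation = np.array(Image.open("seg.png").convert("RGB"), dtype=np.float64) / 255
# masks = extract_masks(np_segmentation)   # list of 9 boolean (H, W) arrays, blue first
# plt.imshow(masks[0], cmap="gray"); plt.show()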
def get_all_masks(path):
    """
    Returns the segmentation masks extracted from the segmented image.
    """
    image = Image.open(path).convert('RGB')
    np_image = np.array(image, dtype=np.float64) / 255
    return extract_masks(np_image)
def is_nonzero(mask, thrs=0.01):
    """
    Checks that a segmentation mask covers more than `thrs` of the image.
    """
    return np.sum(mask) / mask.size > thrs
def get_masks(path_style, path_content):
    """
    Returns the meaningful segmentation masks, i.e. only the labels that cover
    a non-negligible area in both the style and the content segmentation.
    This avoids the "orphan semantic labels" problem, where a label present in
    only one of the two images has no counterpart to transfer style to or from.
    """
    masks_style = get_all_masks(path_style)
    masks_content = get_all_masks(path_content)
    non_zero_masks = [
        is_nonzero(mask_c) and is_nonzero(mask_s)
        for mask_c, mask_s in zip(masks_content, masks_style)
    ]
    masks_style = [mask for mask, cond in zip(masks_style, non_zero_masks) if cond]
    masks_content = [mask for mask, cond in zip(masks_content, non_zero_masks) if cond]
    return masks_style, masks_content
def resize_masks(masks_style, masks_content, size):
    """
    Resizes masks to the given size.
    """
    resize_mask = lambda mask: resize(mask, size, mode="reflect")
    masks_style = [resize_mask(mask) for mask in masks_style]
    masks_content = [resize_mask(mask) for mask in masks_content]
    return masks_style, masks_content
def masks_to_tensor(masks_style, masks_content):
    """
    Transforms masks from np.array to torch.Tensor.
    """
    masks_style = [image_to_tensor(mask) for mask in masks_style]
    masks_content = [image_to_tensor(mask) for mask in masks_content]
    return masks_style, masks_content
def masks_loader(path_style, path_content, size):
    """
    Loads, filters, resizes and tensorizes the segmentation masks.
    """
    style_masks, content_masks = get_masks(path_style, path_content)
    style_masks, content_masks = resize_masks(style_masks, content_masks, size)
    style_masks, content_masks = masks_to_tensor(style_masks, content_masks)
    return style_masks, content_masks
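# Illustrative usage of masks_loader (paths and size are placeholders):
# style_masks, content_masks = masks_loader("seg_style.png", "seg_content.png", (400, 533))
# each returned list holds one (1, 1, H, W) float tensor per retained label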
def get_features(image, model, layers=None):
    """ Run an image forward through a model and get the features for
    a set of layers. Default layers match the VGG19 layers used by
    Gatys et al. (2016).
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content representation
            '28': 'conv5_1'
        }
    features = {}
    x = image
    # model._modules is a dictionary holding each module in the model
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features
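# The numeric keys above index torchvision's vgg19 `features` Sequential. A quick,
# purely illustrative way to verify the mapping:
# print(models.vgg19().features)   # indices 0, 5, 10, 19, 21, 28 are the conv layers named above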
def gram_matrix(tensor):
    """ Calculate the Gram matrix of a given tensor (batch size is assumed to be 1).
    Gram matrix: https://en.wikipedia.org/wiki/Gramian_matrix
    """
    # get the batch_size, depth, height, and width of the Tensor
    _, d, h, w = tensor.size()
    # reshape so we're multiplying the features for each channel
    tensor = tensor.view(d, h * w)
    # calculate the gram matrix
    gram = torch.mm(tensor, tensor.t())
    return gram
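# For features F reshaped to (d, h*w) the Gram matrix is G = F @ F.T, i.e.
# G[i, j] = sum_k F[i, k] * F[j, k]. Illustrative sanity check:
# gram_matrix(torch.randn(1, 64, 32, 32)).shape   # -> torch.Size([64, 64])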
def photorealism_regularization(image, target):
    """ Photorealism regularization: sum_c V[:, c]^T L V[:, c], where L is the
    matting Laplacian of `image` and V is `target` reshaped to (N, 3).
    `compute_laplacian` is not defined in this file; it is assumed to come from
    a closed-form matting implementation. """
    L = compute_laplacian(image)
    V = target.reshape(-1, 3)
    grad = L.dot(V)
    loss = np.sum(V * grad)
    return torch.tensor(loss), 2. * grad.reshape(*target.shape)
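# The matting Laplacian above depends only on the content photo, so in practice it
# could be computed once before the optimization loop instead of on every iteration.
# A minimal sketch, assuming `compute_laplacian` returns a scipy.sparse matrix and
# `content_np` is the (H, W, 3) content image in [0, 1] (both names illustrative):
#
# L_content = compute_laplacian(content_np)
# def reg_from_laplacian(L, out):
#     V = out.reshape(-1, 3)
#     grad = L.dot(V)
#     return np.sum(V * grad), 2. * grad.reshape(*out.shape)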
# the images are read from Colab upload buffers; `uploaded` is assumed to be the
# dict returned by `from google.colab import files; uploaded = files.upload()`
style_path = io.BytesIO(uploaded['tar4.png'])
content_path = io.BytesIO(uploaded['in4.png'])
style_mask_path = io.BytesIO(uploaded['seg_tar4.png'])
content_mask_path = io.BytesIO(uploaded['seg_in4.png'])
# load the pretrained VGG19 feature extractor and freeze its weights
# (on newer torchvision use `models.vgg19(weights=models.VGG19_Weights.DEFAULT)`)
vgg = models.vgg19(pretrained=True).features
for param in vgg.parameters():
    param.requires_grad_(False)
vgg.to(device)
# load content and style
content = load_image(content_path).to(device)
# resize style to match content
style = load_image(style_path, shape = content.shape[-2:]).to(device)
# load the masks
style_masks, content_masks = masks_loader(
    style_mask_path,
    content_mask_path,
    content.shape[-2:])
# get content and style features
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)
# get the gram matrices
style_grams = {layer : gram_matrix(style_features[layer]) for layer in style_features}
# the target image is initialized from the content image and optimized directly;
# requires_grad_ is called last so `target` stays a leaf tensor
target = content.clone().to(device).requires_grad_(True)
style_weights = {'conv1_1': 0.2,
                 'conv2_1': 0.2,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}
content_weight = 1 # Alpha
style_weight = 1e2 # Gamma
photorealismRegularization_weight = 1e4 # lambda
# constants
content_layer = 'conv4_2'
# for displaying the target image, intermittently
show_every = 400
# iteration hyperparameters
optimizer = optim.Adam([target], lr=0.003)
steps = 2000  # number of iterations used to update the image (e.g. 5000 for longer runs)
for ii in range(1, steps+1):
    # get the features from your target image
    target_features = get_features(target, vgg)
    # content loss
    content_loss = torch.mean((target_features[content_layer] - content_features[content_layer]) ** 2)
    # photorealistic regularization loss (computed in numpy, outside the autograd graph;
    # its gradient is added to target.grad by hand after backward(), see below)
    content_np = content.detach().to("cpu").numpy().transpose(0, 2, 3, 1).squeeze().clip(0, 1)
    target_np = target.detach().to("cpu").numpy().transpose(0, 2, 3, 1).squeeze().clip(0, 1)
    reg_loss, reg_grad = photorealism_regularization(content_np, target_np)
    reg_grad_tensor = image_to_tensor(reg_grad).to(device)
    # augmented style loss (masked Gram matrices, computed layer by layer below)
    layers = {
        '0': 'conv1_1',
        '5': 'conv2_1',
        '10': 'conv3_1',
        '19': 'conv4_1',
        '28': 'conv5_1'
    }
    style_loss = 0
    x = target.clone().to(device)
    # work on fresh copies of the masks each iteration, otherwise they would be
    # downsampled again and again across outer iterations
    s_masks = [mask.to(device) for mask in style_masks]
    c_masks = [mask.to(device) for mask in content_masks]
    # model._modules is a dictionary holding each module in the model
    for name, layer in vgg._modules.items():
        x = layer(x)
        if isinstance(layer, torch.nn.MaxPool2d):
            # downsample the masks together with the feature maps
            s_masks = [layer(mask) for mask in s_masks]
            c_masks = [layer(mask) for mask in c_masks]
        if name in layers:
            # augmented style loss: target features are masked with the content
            # masks, style features with the corresponding style masks
            target_gram = [gram_matrix(x * mask) for mask in c_masks]
            style_gram = [gram_matrix(style_features[layers[name]] * mask) for mask in s_masks]
            layer_loss = style_weights[layers[name]] * sum(
                F.mse_loss(t_gram, s_gram) for t_gram, s_gram in zip(target_gram, style_gram))
            _, d, h, w = x.shape
            layer_loss = layer_loss / (d * h * w)
            style_loss = style_loss + layer_loss
    # total loss
    total_loss = (content_weight * content_loss + style_weight * style_loss
                  + photorealismRegularization_weight * reg_loss)
    # update target image
    optimizer.zero_grad()
    total_loss.backward()
    # reg_loss carries no autograd graph, so the regularization gradient
    # is added to the image gradient manually before the optimizer step
    target.grad += photorealismRegularization_weight * reg_grad_tensor
    optimizer.step()
    # display intermediate images and print the loss
    if ii % show_every == 0:
        print('Total loss: ', total_loss.item())
        plt.imshow(im_convert(target))
        plt.show()
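# After the loop one would typically keep or save the final stylized image, e.g.
# (a minimal sketch; the file name is arbitrary):
# final = im_convert(target)
# plt.imsave("output.png", final)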