Hi guys, I encountered an odd problem while implementing neural style transfer.
Here is my code:
class Normalization(nn.Module):
    """Normalize an image batch channel-wise with fixed mean and std.

    The statistics are hard-coded three-channel values (presumably the
    ImageNet statistics the pretrained backbone expects -- confirm against
    the weights actually used).

    Args:
        device: Device on which the statistics are initially placed.
    """

    def __init__(self, device):
        super(Normalization, self).__init__()
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # Reshape to [C x 1 x 1] so the statistics broadcast against an image
        # tensor of shape [B x C x H x W] (B batch, C channels, H height,
        # W width).
        # They are registered as buffers so that later calls to .to(),
        # .cuda() or .double() on the module move/cast them together with
        # the rest of the network. persistent=False keeps them out of
        # state_dict(), so saved checkpoints are unaffected.
        self.register_buffer(
            'mean',
            torch.tensor(mean).view(-1, 1, 1).to(device),
            persistent=False,
        )
        self.register_buffer(
            'std',
            torch.tensor(std).view(-1, 1, 1).to(device),
            persistent=False,
        )

    def forward(self, img):
        """Return ``(img - mean) / std``, broadcast over batch and space."""
        return (img - self.mean) / self.std
def gram_matrix(input):
    """Compute a normalized Gram matrix for every sample in a batch.

    Args:
        input: 4-D tensor; its size is unpacked as
            (batch, feature maps, height, width).

    Returns:
        Tensor of shape (batch, feature maps, feature maps) holding, per
        sample, the inner products between flattened feature maps divided
        by ``feature maps * height * width``.
    """
    batch, channels, height, width = input.size()
    # One (channels x N) matrix per sample, with N = height * width.
    flat = input.view(batch, channels, height * width)
    # Batched product of each matrix with its own transpose.
    gram = torch.einsum('bik,bjk->bij', flat, flat)
    # 'Normalize' the Gram values by the number of elements in one
    # sample's feature maps.
    n_elements = channels * height * width
    return gram.div(n_elements)
def calc_style_loss(input, target):
    """Return the MSE between the Gram matrix of ``input`` and ``target``.

    Args:
        input: Feature-map tensor; its Gram matrix is computed here.
        target: Tensor compared against that Gram matrix as-is (it is not
            passed through ``gram_matrix`` again).
    """
    return F.mse_loss(gram_matrix(input), target)
def calc_content_loss(input, target):
    """Return the mean squared error between ``input`` and ``target``.

    Nothing is detached here: both tensors are handed to ``F.mse_loss``
    exactly as given, so ``target`` should already be a fixed value rather
    than part of the graph being optimised.
    """
    content_loss = F.mse_loss(input, target)
    return content_loss
class TransferModel(nn.Module):
    """Loss network for neural style transfer.

    A truncated copy of the VGG19 feature extractor (up to ``conv_5``) is
    wrapped behind a ``Normalization`` layer. Forward hooks on selected
    conv layers capture feature maps. ``set_target`` records the content
    and style targets; ``forward`` returns the weighted sum of style and
    content losses for an input image.

    Args:
        pretrained: Forwarded to ``VGG19(pretrained=...)``.
        device: Device the network is placed on.
        style_weight: Multiplier applied to the summed style losses.
        content_weight: Multiplier applied to the summed content losses.
    """

    def __init__(self, pretrained, device, style_weight=1e6, content_weight=1.):
        super(TransferModel, self).__init__()
        content_layers = ['conv_4']
        style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        self.style_losses = []
        self.content_losses = []
        self.style_weight = style_weight
        self.content_weight = content_weight
        self.style_feature_maps, self.content_feature_maps = [], []
        # Hook mode and targets are defined up front so that a hook firing,
        # or forward() being called, before set_target() fails with a clear
        # message instead of an AttributeError.
        self.detach = False
        self.target_content_feature = None
        self.target_style_feature = None
        self.model = self.__build_network__(device, pretrained, content_layers, style_layers).eval()

    def _capture(self, output):
        """Return the tensor a hook should store for a layer ``output``."""
        # While recording targets a detached copy is stored; during
        # optimisation the live tensor is kept so gradients can flow.
        if self.detach:
            return output.clone().detach()
        return output

    def style_hook(self, module, input, output):
        '''
        Storing feature maps in calculating style loss
        '''
        self.style_feature_maps.append(self._capture(output))

    def content_hook(self, module, input, output):
        '''
        Storing feature maps in calculating content loss
        '''
        # Must go into the *content* list: appending to the style list
        # leaves the content list empty, so the content loss is always zero
        # and this layer is counted twice in the style loss.
        self.content_feature_maps.append(self._capture(output))

    def __build_network__(self, device, pretrained, content_layers, style_layers):
        """Build the truncated, hooked feature extractor.

        Layers of ``VGG19(...).features`` are deep-copied one by one into a
        new ``nn.Sequential`` that starts with ``Normalization``; copying
        stops after ``conv_5``.

        Raises:
            RuntimeError: If a layer is not Conv2d, ReLU, MaxPool2d or
                BatchNorm2d.
        """
        # assuming that the VGG features are a nn.Sequential, so we make a
        # new nn.Sequential to put in modules that are supposed to be
        # activated sequentially
        normalization = Normalization(device=device)
        model = nn.Sequential(normalization).to(device)
        vgg = VGG19(pretrained=pretrained).features.to(device).eval()
        i = 0  # increment every time we see a conv
        for layer in vgg.children():
            if isinstance(layer, nn.Conv2d):
                i += 1
                name = 'conv_{}'.format(i)
            elif isinstance(layer, nn.ReLU):
                name = 'relu_{}'.format(i)
                # An in-place ReLU would overwrite the conv output that the
                # hooks have just captured, so use an out-of-place one.
                layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                name = 'pool_{}'.format(i)
            elif isinstance(layer, nn.BatchNorm2d):
                name = 'bn_{}'.format(i)
            else:
                raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
            n_layer = copy.deepcopy(layer)
            model.add_module(name, n_layer)
            if name in content_layers:
                # add content hook
                n_layer.register_forward_hook(self.content_hook)
            if name in style_layers:
                # add style hook
                n_layer.register_forward_hook(self.style_hook)
            if name == 'conv_5':
                break
        del vgg
        return model

    @torch.no_grad()
    def set_target(self, content_img, style_img):
        """Record content and style targets from the two reference images.

        Runs the network once per image with the hooks in detach mode.
        Content targets are the captured content feature maps; style
        targets are the Gram matrices of the captured style feature maps.
        """
        self.detach = True
        self.content_feature_maps, self.style_feature_maps = [], []
        self.model(content_img)
        self.target_content_feature = list(self.content_feature_maps)
        self.content_feature_maps, self.style_feature_maps = [], []
        self.model(style_img)
        self.target_style_feature = [gram_matrix(fmap) for fmap in self.style_feature_maps]
        # Only the targets are needed from here on; drop the raw maps so
        # they do not stay resident until the next forward pass.
        self.content_feature_maps, self.style_feature_maps = [], []

    def forward(self, x):
        """Return ``style_weight * sum(style) + content_weight * sum(content)``.

        The per-layer losses of the latest call are kept in
        ``self.style_losses`` and ``self.content_losses``.

        Raises:
            RuntimeError: If ``set_target`` has not been called yet.
        """
        if self.target_content_feature is None or self.target_style_feature is None:
            raise RuntimeError('set_target() must be called before forward()')
        self.detach = False
        self.content_feature_maps, self.style_feature_maps = [], []
        self.model(x)
        self.content_losses = [
            calc_content_loss(feature, target_feature)
            for feature, target_feature in zip(self.content_feature_maps, self.target_content_feature)
        ]
        self.style_losses = [
            calc_style_loss(feature, target_feature)
            for feature, target_feature in zip(self.style_feature_maps, self.target_style_feature)
        ]
        style_score = sum(self.style_losses, 0.0) * self.style_weight
        content_score = sum(self.content_losses, 0.0) * self.content_weight
        return style_score + content_score
I use VGG as my backbone and wrote a method, set_target, to set the content image and the style image. The optimization process is shown below:
# LBFGS optimises the image itself, so the image must be a tensor that
# requires a gradient.
optimizer = optim.LBFGS([optim_img.requires_grad_()])
transfer_model.set_target(content_img, style_img)


def closure():
    """Re-evaluate the transfer loss for LBFGS and return it."""
    # Correct the values of the updated image. Clamping under no_grad keeps
    # the in-place op out of the autograd graph without going through .data.
    # No `nonlocal` is needed: the name is only read, never rebound.
    with torch.no_grad():
        optim_img.clamp_(0, 1)
    optimizer.zero_grad()
    loss = transfer_model(optim_img)
    loss.backward()
    return loss


_iter = 0
style_transfer = 0.
# The closure does not depend on the loop state, so it is defined once
# above instead of being rebuilt on every iteration.
while _iter < cfg.style.max_iteres:
    optimizer.step(closure)
    _iter += 1
However, this code runs for only 10 rounds before an out-of-memory (OOM) error occurs.
Can anyone help me solve this problem?