Hi, I’m working on translating some style transfer code from Torch to PyTorch, and I’m running into issues, probably because I’m not using autograd correctly. I can run all the way through building my network and taking optimization steps, but the loss never decreases (it prints the same value on every iteration). I’m not particularly experienced with Torch and even less so with PyTorch, so chances are I’m missing something obvious.
I’ve built my network (a frozen vgg19) into an nn.Sequential that prints like this (a rough sketch of the build loop follows the printout):
Net
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): StyleLoss(
(gram): GramMatrix()
(mse): MSELoss()
)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): ReLU()
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): ReLU()
(8): StyleLoss(
(gram): GramMatrix()
(mse): MSELoss()
)
(9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(10): ReLU()
(11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU()
(14): StyleLoss(
(gram): GramMatrix()
(mse): MSELoss()
)
(15): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(16): ReLU()
(17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU()
(19): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU()
(21): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(22): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(23): ReLU()
(24): StyleLoss(
(gram): GramMatrix()
(mse): MSELoss()
)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU()
(27): ContentLoss(
(mse): MSELoss()
)
(28): ReLU()
(29): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(30): ReLU()
(31): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU()
(34): StyleLoss(
(gram): GramMatrix()
(mse): MSELoss()
)
)
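For context, the build loop is roughly shaped like this. This is a paraphrase, not my exact code: the layer indices, the args.* weight names, and style_img/content_img are illustrative, and the temporal losses used in the closure further down are built the same way and omitted here.

import torch
import torchvision

# Frozen VGG-19 feature stack; only the image will be optimized.
vgg = torchvision.models.vgg19(pretrained=True).features.eval()
for p in vgg.parameters():
    p.requires_grad_(False)

style_idx = {1, 6, 11, 20, 29}  # relu*_1 positions in vgg.features (illustrative)
content_idx = {22}              # relu4_2 (illustrative)

net = torch.nn.Sequential()
style_losses, content_losses = [], []
for i, layer in enumerate(vgg.children()):
    net.add_module(str(len(net)), layer)
    if i in style_idx:
        # Target Gram matrix from the style image at this depth, as a constant.
        target = GramMatrix()(net(style_img)).detach()
        sl = StyleLoss(args.style_weight, target, args.normalize_gradients)
        net.add_module(str(len(net)), sl)
        style_losses.append(sl)
    if i in content_idx:
        # Target activations from the content image at this depth.
        target = net(content_img).detach()
        cl = ContentLoss(args.content_weight, target, args.normalize_gradients)
        net.add_module(str(len(net)), cl)
        content_losses.append(cl)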
With ContentLoss and StyleLoss defined as follows:
ContentLoss
class ContentLoss(torch.nn.Module):
    def __init__(self, strength, target, normalize):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.target = target
        self.normalize = normalize
        self.loss = 0
        self.mse = torch.nn.MSELoss()

    def forward(self, input):
        print(input.shape)
        print(self.target.shape)
        print(input.nelement())
        print(self.target.nelement())
        if input.nelement() == self.target.nelement():
            self.loss = self.mse.forward(input, self.target) * self.strength
        else:
            print('WARNING: Skipping content loss')
        output = input
        return output

    def backward(self, input, grad_output):
        if input.nelement() == self.target.nelement():
            grad_input = self.mse.backward(input, self.target)
        if self.normalize:
            grad_input.div(torch.norm(grad_input, 1) + 1e-8)
        grad_input.mul(self.strength)
        grad_input.add(grad_output)
        return grad_input
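For what it’s worth, my understanding of PyTorch is that an nn.Module normally only defines forward and autograd derives the gradients itself, so I suspect my hand-written backward methods are never actually called. Is a forward-only version like this (untested sketch, dropping the normalize path for simplicity) the right shape?

class ContentLoss(torch.nn.Module):
    # Untested sketch: leave the backward pass entirely to autograd.
    def __init__(self, strength, target):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.target = target.detach()  # treat the target activations as a constant
        self.loss = 0
        self.mse = torch.nn.MSELoss()

    def forward(self, input):
        if input.nelement() == self.target.nelement():
            self.loss = self.mse(input, self.target) * self.strength
        return input  # pass the activations through unchanged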
StyleLoss
class GramMatrix(torch.nn.Module):
    def forward(self, input):
        a, b, c, d = input.shape  # a = batch size (=1)
        features = input.contiguous().view(a * b, c * d)  # resize F_XL into \hat F_XL
        G = torch.mm(features, features.t()).float()  # compute the Gram product
        return G.div(a * b * c * d)

class StyleLoss(torch.nn.Module):
    def __init__(self, strength, target, normalize):
        super(StyleLoss, self).__init__()
        self.normalize = normalize
        self.strength = strength
        self.target = target
        self.loss = 0
        self.gram = GramMatrix()
        self.G = None
        self.mse = torch.nn.MSELoss()

    def forward(self, input):
        self.G = self.gram.forward(input)
        self.G.div(input.nelement())
        self.loss = self.mse.forward(self.G, self.target)
        self.loss = self.loss * self.strength
        output = input
        return output

    def backward(self, input, grad_output):
        dG = self.mse.backward(self.G, self.target)
        dG.div(input.nelement())
        grad_input = self.gram.backward(input, dG)
        if self.normalize:
            grad_input.div(torch.norm(grad_input, 1) + 1e-8)
        grad_input.mul(self.strength)
        grad_input.add(grad_output)
        return grad_input
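The corresponding forward-only StyleLoss would presumably be the sketch below (also untested). One thing I noticed while writing this up: self.G.div(input.nelement()) above is out-of-place, so its result is discarded; since GramMatrix already divides by a * b * c * d, I’ve dropped that line here.

class StyleLoss(torch.nn.Module):
    # Untested sketch: leave the backward pass entirely to autograd.
    def __init__(self, strength, target):
        super(StyleLoss, self).__init__()
        self.strength = strength
        self.target = target.detach()  # target Gram matrix as a constant
        self.loss = 0
        self.gram = GramMatrix()
        self.mse = torch.nn.MSELoss()

    def forward(self, input):
        G = self.gram(input)  # Gram matrix of the current activations
        self.loss = self.mse(G, self.target) * self.strength
        return input  # pass the activations through unchanged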
Finally, I’m trying to run the optimization like this:
y = net.forward(img)
dy = torch.zeros(y.shape)

def closure():
    # optimizer.zero_grad()
    net.forward(img)  # the Torch code uses x for img here
    torch.autograd.backward(img, dy)  # and here calls net:backward instead of autograd
    loss = 0
    for mod in content_losses: loss += mod.loss
    for mod in temporal_losses: loss += mod.loss
    for mod in style_losses: loss += mod.loss
    # loss.requires_grad_(True)
    # loss.backward()
    print(loss.item())
    return loss

# Run optimization.
optimizer = torch.optim.LBFGS([img.requires_grad_()],
                              lr=args.learning_rate,
                              max_iter=args.num_iterations,
                              tolerance_change=args.tol_loss_relative)
for iter in range(args.num_iterations):  # this loop seems odd to me, since I thought LBFGS handled iteration internally via the max_iter parameter...
    optimizer.step(closure)

img_out = np.moveaxis(img.detach().squeeze().numpy(), 0, -1)
skimage.io.imsave(args.img_filename.format(0, 1), img_out)
The Torch code I’m following passes the image into the closure as x and uses x in place of img inside the closure body (see my comments in closure()); I wasn’t able to reproduce that with step(), since step() expects a reference to a zero-argument function.
I’ve also tried calling backward() directly on the loss value inside closure(), as shown in the commented-out lines near the bottom.
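Spelled out, that attempt (with the torch.autograd.backward call removed and the commented lines enabled) looks roughly like this:

def closure():
    optimizer.zero_grad()  # clear the image gradient from the previous step
    net(img)               # forward pass populates each loss module's .loss
    loss = 0
    for mod in content_losses: loss += mod.loss
    for mod in temporal_losses: loss += mod.loss
    for mod in style_losses: loss += mod.loss
    loss.backward()        # let autograd compute the gradient of loss w.r.t. img
    print(loss.item())
    return loss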
Either way, the loss prints the same value for every iteration, and the final optimized picture does not look stylized (I think because effectively only one iteration of optimization is run).
How can I make sure the image is optimized correctly against the style and content losses I’ve defined?