Why isn't my simple model learning the right information?

I have created a simple CPPN model, and the goal is to make it learn to create outputs that maximize the activations of a chosen layer in Torchvision’s pre-trained VGG-19 model. Example diagram of what I’m trying to do (ignore the channel part):

I have thus far been unable to get the CPPN model to learn the right information, and it seems to just create swirling shapes.

The code for training the CPPN model can be found below:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models


# Function to maximize CNN activations
def dream_cppn(net, image_size, iterations, lr, use_device, normalize_input=True):
    """Optimise a fresh CPPN so that its rendered image maximises ``net``'s activations.

    Args:
        net: feature extractor; the L2 norm of its whole output is maximised.
        image_size: (height, width) of the rendered image, or an int for a square one.
        iterations: number of Adam steps.
        lr: Adam learning rate.
        use_device: device the CPPN is built on, e.g. 'cpu' or 'cuda:0'.
        normalize_input: when True (default), apply the ImageNet mean/std
            normalisation that torchvision's pre-trained models expect before
            the CPPN image is fed to ``net``.

    Returns:
        The CPPN's RGB output (values in (0, 1) from its final sigmoid),
        re-rendered with the final weights and without autograd history.
    """
    # Create instance of CPPN
    img_cppn = CPPN_Conv(size=image_size, num_channels=24, num_layers=8, use_device=use_device)
    print('CPPN Params Count', sum(p.numel() for p in img_cppn.parameters() if p.requires_grad))

    # The CPPN emits [0, 1] RGB, but pre-trained torchvision models were trained
    # on mean/std-normalised images; feeding raw [0, 1] values skews activations.
    mean = torch.tensor([0.485, 0.456, 0.406], device=use_device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=use_device).view(1, 3, 1, 1)

    # Setup optimizer to optimize CPPN instance
    optimizer = torch.optim.Adam(list(img_cppn.parameters()), lr=lr)

    # Training loop
    for i in range(iterations):
        optimizer.zero_grad()
        img = img_cppn()  # Create RGB image with CPPN
        net_input = (img - mean) / std if normalize_input else img
        out = net(net_input)  # Get activations from CNN

        # Maximise the L2 norm of the activations by minimising its negative.
        loss = -out.pow(2).sum().sqrt()
        loss.backward()

        # Uncomment to save iterations
        #simple_deprocess(img.detach(), 'out_'+str(i)+'.png')

        print('Iteration', str(i+1), 'Loss', str(loss.item()))

        optimizer.step()

    # Render once more so the result reflects the last optimiser step (and so
    # iterations == 0 still yields an image instead of an unbound name).
    with torch.no_grad():
        return img_cppn()


# Activation function for CPPN
class CompositeActivation(nn.Module):
    """Channel-doubling activation used between the CPPN's conv layers.

    The input is squashed with ``atan``; then a linear copy (scaled by 1/0.67)
    and a squared copy (scaled by 1/0.6) are concatenated along dim 1, so the
    output has twice as many channels as the input.
    """

    def forward(self, input):
        squashed = input.atan()
        linear_part = squashed / 0.67
        square_part = (squashed * squashed) / 0.6
        return torch.cat((linear_part, square_part), dim=1)


# Compositional pattern-producing network (CPPN) with Conv2d layers
class CPPN_Conv(nn.Module):
    """Compositional pattern-producing network built from 1x1 Conv2d layers.

    Maps a fixed grid of (x, y) coordinates to an RGB image through a stack of
    1x1 convolutions, optional instance normalisation and activations, ending
    in a 3-channel conv followed by a sigmoid.

    Args:
        size: (height, width) of the rendered image, or an int for a square one.
        num_channels: output channels of each hidden conv layer.
        num_layers: number of hidden conv layers.
        activ_func: activation module applied after every hidden layer. It must
            double the channel count (as CompositeActivation does), because each
            following conv expects ``num_channels * 2`` input channels. Defaults
            to a new CompositeActivation per instance.
        use_device: device the network and coordinate grid are created on.
    """

    def __init__(self, size=(405, 512), num_channels=24, num_layers=8, activ_func=None, use_device='cpu'):
        super(CPPN_Conv, self).__init__()
        if activ_func is None:
            # Built here rather than in the signature so instances don't all
            # share a single object created when the class is defined.
            activ_func = CompositeActivation()
        self.input_size = size
        self.n_channels = num_channels
        self.net = self.create_net(num_channels, num_layers, activ_func, use_device)
        # Non-persistent buffer: it follows the module through .to()/.cuda()
        # but stays out of state_dict, since it is recomputable from the size.
        self.register_buffer('cppn_input', self.create_input(use_device), persistent=False)

    # Create CPPN (X,Y) --> (R,G,B)
    def create_net(self, num_channels, num_layers, activ_func, use_device, bias=True, instance_norm=True, affine=False):
        """Build the Sequential stack.

        Layout: conv(2 -> C), then num_layers - 1 convs (2C -> C), each followed
        by optional InstanceNorm2d and ``activ_func``, then conv(2C -> 3) + Sigmoid.
        Modules are named "0", "1", ... in insertion order.
        """
        net = nn.Sequential()
        net.add_module(str(len(net)), nn.Conv2d(in_channels=2, out_channels=num_channels, kernel_size=1, bias=bias))
        if instance_norm:
            net.add_module(str(len(net)), nn.InstanceNorm2d(num_channels, affine=affine))
        net.add_module(str(len(net)), activ_func)
        for _ in range(num_layers - 1):
            # activ_func doubled the channels, hence num_channels * 2 inputs.
            net.add_module(str(len(net)), nn.Conv2d(in_channels=num_channels*2, out_channels=num_channels, kernel_size=1, bias=bias))
            if instance_norm:
                net.add_module(str(len(net)), nn.InstanceNorm2d(num_channels, affine=affine))
            net.add_module(str(len(net)), activ_func)
        net.add_module(str(len(net)), nn.Conv2d(in_channels=num_channels*2, out_channels=3, kernel_size=1, bias=bias))
        net.add_module(str(len(net)), nn.Sigmoid())
        net.apply(self.cppn_normal)
        return net.to(use_device)

    # Create X,Y input for CPPN
    def create_input(self, use_device):
        """Return a (1, 2, H, W) coordinate grid.

        Channel 0 is x (varies along the width), channel 1 is y (varies along
        the height). Each axis is k / n - 0.5 for k = 0..n-1, where n is that
        axis's own length.
        """
        if not isinstance(self.input_size, (tuple, list)):
            self.input_size = (self.input_size, self.input_size)
        height, width = self.input_size
        # Normalise each axis by its own length so x and y cover the same range
        # even when the image is not square.
        xs = torch.arange(width, device=use_device).true_divide(width) - 0.5
        ys = torch.arange(height, device=use_device).true_divide(height) - 0.5
        x_grid = xs.view(1, width).expand(height, width)
        y_grid = ys.view(height, 1).expand(height, width)
        return torch.stack((x_grid, y_grid), 0).unsqueeze(0)

    # Normalize Conv2d weights
    def cppn_normal(self, l):
        """Draw Conv2d weights from N(0, 1 / n_channels); other modules are untouched.

        NOTE(review): the same std is used for every conv, although the first
        conv has 2 input channels and the others 2 * n_channels; a fan-in based
        scale may have been intended. Confirm before relying on it.
        """
        if isinstance(l, nn.Conv2d):
            l.weight.data.normal_(0, (1.0/self.n_channels)**(1/2))

    def forward(self):
        """Render the RGB image, shape (1, 3, H, W), from the coordinate grid."""
        return self.net(self.cppn_input)


# Simple deprocess and save image
def simple_deprocess(output_tensor, name):
    """Save a single RGB image tensor with values in [0, 1] to the file ``name``.

    A leading batch dimension of size 1 is squeezed away, and values outside
    [0, 1] are clamped. The caller's tensor is left unmodified.
    """
    # Work on a detached copy and clamp out of place: for a CPU tensor .cpu()
    # is a no-op, so an in-place clamp_ would overwrite the caller's data (and
    # raises for an nn.Parameter that requires grad).
    picture = output_tensor.detach().squeeze(0).cpu().clamp(0, 1)
    transforms.ToPILImage()(picture).save(name)


# Main function to setup params and run code
def main():
    """Build a truncated VGG-19 feature extractor, dream with a CPPN and save out.png."""
    image_size = (512, 512)
    iterations = 50
    lr = 0.1
    # Fall back to the CPU so the script still runs without CUDA.
    use_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): torchvision's vgg19().features has 37 modules (indices 0-36),
    # so 37 is past the end and the slice below keeps the whole stack, ending in
    # the last max-pool. Confirm this is the intended target layer.
    dream_layer = 37

    # Uncomment if loading model from file
    #cnn = models.vgg19(pretrained=False)
    #cnn.load_state_dict(torch.load('vgg19-dcbb9e9d.pth'))
    cnn = models.vgg19(pretrained=True)
    layers = list(cnn.features.children())
    net = nn.Sequential(*layers[: (dream_layer + 1)]).to(use_device)  # Remove layers above target layer

    # The CNN is only a fixed scorer. Freeze it so backward() computes gradients
    # for the image alone, instead of also accumulating VGG gradients that no
    # optimizer ever clears. eval() makes the inference-only intent explicit.
    net.eval()
    for param in net.parameters():
        param.requires_grad_(False)

    output_tensor = dream_cppn(net, image_size, iterations, lr, use_device)
    simple_deprocess(output_tensor, 'out.png')
	
# Only run the (slow, GPU- and network-dependent) dream when executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    main()

I am attempting to replicate some seemingly simple results from here: Differentiable Image Parameterizations, but you shouldn’t need to read the article to understand the code. The original results were created using TensorFlow, but I’ve recreated everything in PyTorch, and my CPPN model has the same number of parameters as theirs.

Any help with getting this to work properly would be appreciated! I can answer any questions about the code, but I don’t know exactly what I’ve done wrong.

The above code produces results that look more like this:

out

It should produce results like these:


This version of the training function uses a randomly generated tensor as the item being optimized, and works correctly:

# Use random tensor instead of CPPN
def dream_cppn(net, image_size, iterations, lr, use_device):
    """Optimise a raw image tensor (no CPPN) so that it maximises ``net``'s activations.

    The image starts as standard-normal noise of shape
    (1, 3, image_size[0], image_size[1]). It is itself the nn.Parameter that
    Adam updates, and it is returned after the final step.
    """
    height, width = image_size[0], image_size[1]
    img = nn.Parameter(torch.randn(1, 3, height, width).to(use_device))
    optimizer = torch.optim.Adam([img], lr=lr)

    for step in range(1, iterations + 1):
        optimizer.zero_grad()
        activations = net(img)
        # Maximise the L2 norm of the whole activation map via its negative.
        loss = -activations.pow(2).sum().sqrt()
        loss.backward()

        # Uncomment to save iterations
        #simple_deprocess(img.detach(), 'iters/out_'+str(step - 1)+'.png')

        print('Iteration', str(step), 'Loss', str(loss.item()))
        optimizer.step()
    return img

out_r

  • Structures from the target layer are visible.

It seems to work better with non-random images:

from PIL import Image
def preprocess(image_name, image_size):
    """Load ``image_name`` as a (1, 3, H, W) float tensor with the ImageNet mean subtracted.

    The image is converted to RGB and resized with ``transforms.Resize(image_size)``.
    Only the mean is removed (std is [1, 1, 1]); ``deprocess`` adds it back.
    """
    loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    normalize = transforms.Compose([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1,1,1])])
    # Close the file once the pixels are read; convert() forces the lazy load
    # and returns an independent image object.
    with Image.open(image_name) as source:
        rgb = source.convert('RGB')
    return normalize(loader(rgb)).unsqueeze(0)
def deprocess(output_tensor, name='out.png'):
    """Add back the ImageNet mean removed by ``preprocess`` and save the image to ``name``.

    A leading batch dimension of size 1 is squeezed away, and values are clamped to [0, 1].
    """
    normalize = transforms.Compose([transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1,1,1])])
    # detach(): the tensor may be an nn.Parameter; saving needs no autograd graph.
    restored = normalize(output_tensor.detach().squeeze(0).cpu())
    transforms.ToPILImage()(restored.clamp(0, 1)).save(name)

def dream_cppn(...
    img = nn.Parameter(preprocess('test.jpg', image_size).to(use_device)) 

out_c

  • lr = 0.015