Back propagation with customize loss function

Hi,

I am building a model for image synthesis. The architecture for the model is shown below.


the objective is to update start_input to minimize the difference between start_output and target_output. (start_output and target_output are the activation for certain layer in ResNet )

the problem is that when I run the code below, the loss value doesn’t decrease, it is oscillating between certain value, and the distance between start_input and target_input is even increasing sometimes.
Could any one help me figure out where went wrong? Any suggestions and help would be appreciated. Thanks in advance.

the code is here


import torchvision.models as models
import torch 
import torch.nn as nn
import logging
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
                                       save_as_images, display_in_terminal)
import torch.nn.functional as F
import urllib
from PIL import Image
from torchvision import transforms
from numpy import linalg as LA

# load the pretrained weight for resnet and BigGAN
resnet = models.resnet50(pretrained=True)
GAN_model = BigGAN.from_pretrained('biggan-deep-256')

# define class for combine the two models
class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA=modelA
        self.modelB=modelB
    def forward(self,z,class_name, truncation):
        x=self.modelA(z, class_name, truncation)
        # resize the tensor 
        x1=F.interpolate(x,(224,224))
        x2=self.modelB(x1)
        return x1, x2 
truncation = 0.5
m = nn.Softmax(dim=1)
input = torch.randn(1, 1000)
class_vector = m(input)
modelA=GAN_model
modelB=resnet
model=MyEnsemble(modelA, modelB)


def my_loss(target_output,start_output):
    return torch.norm(target_output-start_output)

# generate random input for start point image
target_input=torch.randn(1,128,requires_grad = True)
target_point,target_output=model(target_input,class_vector,truncation)
target_output=target_output.detach()

# generate random input for start point image
start_input=torch.randn(1,128,requires_grad = True)

optimizer=torch.optim.Adam( [start_input,class_vector],lr=0.05)
loss_values=[]
distance_values=[]
model.eval()
for epoch in range(101):
    print (epoch)
    torch.autograd.set_detect_anomaly(True)
    optimizer.zero_grad()
    running_loss=0.0
    running_distance=0.0
    distance=torch.norm(start_input-target_input)
    start_point,start_output=model(start_input,class_vector,truncation)
    loss=my_loss(target_output,start_output)
    print(loss)
    print(distance)
    loss.backward(retain_graph=True)
    print('Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[0]['lr']))
    optimizer.step()
    if epoch==100:
        save_as_images(start_point,file_name='result')
    running_loss+=loss.item()
    running_distance+=distance.item()
    loss_values.append(running_loss)
    distance_values.append(running_distance)