Hi,
I am building a model for image synthesis. The architecture for the model is shown below.
the objective is to update start_input to minimize the difference between start_output and target_output. (start_output and target_output are the activation for certain layer in ResNet )
the problem is that when I run the code below, the loss value doesn’t decrease, it is oscillating between certain value, and the distance between start_input and target_input is even increasing sometimes.
Could any one help me figure out where went wrong? Any suggestions and help would be appreciated. Thanks in advance.
the code is here
import torchvision.models as models
import torch
import torch.nn as nn
import logging
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
save_as_images, display_in_terminal)
import torch.nn.functional as F
import urllib
from PIL import Image
from torchvision import transforms
from numpy import linalg as LA
# load the pretrained weight for resnet and BigGAN
resnet = models.resnet50(pretrained=True)
GAN_model = BigGAN.from_pretrained('biggan-deep-256')
# define class for combine the two models
class MyEnsemble(nn.Module):
def __init__(self, modelA, modelB):
super(MyEnsemble, self).__init__()
self.modelA=modelA
self.modelB=modelB
def forward(self,z,class_name, truncation):
x=self.modelA(z, class_name, truncation)
# resize the tensor
x1=F.interpolate(x,(224,224))
x2=self.modelB(x1)
return x1, x2
truncation = 0.5
m = nn.Softmax(dim=1)
input = torch.randn(1, 1000)
class_vector = m(input)
modelA=GAN_model
modelB=resnet
model=MyEnsemble(modelA, modelB)
def my_loss(target_output,start_output):
return torch.norm(target_output-start_output)
# generate random input for start point image
target_input=torch.randn(1,128,requires_grad = True)
target_point,target_output=model(target_input,class_vector,truncation)
target_output=target_output.detach()
# generate random input for start point image
start_input=torch.randn(1,128,requires_grad = True)
optimizer=torch.optim.Adam( [start_input,class_vector],lr=0.05)
loss_values=[]
distance_values=[]
model.eval()
for epoch in range(101):
print (epoch)
torch.autograd.set_detect_anomaly(True)
optimizer.zero_grad()
running_loss=0.0
running_distance=0.0
distance=torch.norm(start_input-target_input)
start_point,start_output=model(start_input,class_vector,truncation)
loss=my_loss(target_output,start_output)
print(loss)
print(distance)
loss.backward(retain_graph=True)
print('Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[0]['lr']))
optimizer.step()
if epoch==100:
save_as_images(start_point,file_name='result')
running_loss+=loss.item()
running_distance+=distance.item()
loss_values.append(running_loss)
distance_values.append(running_distance)