I have the following codes
# Per essere più puliti definire solo la loss, poi il rischio empirico si calcola separatamente
from torch import nn
class AdversarialLoss(nn.Module):
def __init__(self, discriminator):
super(AdversarialLoss, self).__init__()
self.discriminator = discriminator
def forward(self, output):
# calcolo le probabilita' per le diverse immagini upscaled
return -torch.log(self.discriminator(output))
class VggLoss(nn.Module):
def __init__(self, vgg19_model):
super(VggLoss, self).__init__()
self.vgg19_model = vgg19_model
self.mse_loss = MSELoss()
def get_feature_maps(self, module, input, output):
self.feature_maps = output
def forward(self, output, target):
target_layer = self.vgg19_model.features[7]
self.feature_maps = None
hook_handle = target_layer.register_forward_hook(self.get_feature_maps)
self.vgg19_model(output)
vgg_feature_map_upscaled_img = self.feature_maps
self.vgg19_model(target)
vgg_feature_map_high_res_img = self.feature_maps
hook_handle.remove()
return self.mse_loss(vgg_feature_map_upscaled_img, vgg_feature_map_high_res_img)
# N.B. : il problema sembrerebbe essere tra discriminator loss e adversarial loss
class DiscriminatorLoss(nn.Module):
def __init__(self, discriminator):
super(DiscriminatorLoss, self).__init__()
self.discriminator = discriminator
def forward(self, output, target):
probability_high_res_img = self.discriminator(target)
probability_upscaled_img = 1 - self.discriminator(output)
return torch.log(probability_high_res_img) + torch.log(probability_upscaled_img)
from torch.optim import Adam
def train_gan(num_of_iter, batch_size, device_type, data_loader, generator, discriminator, vgg19_model):
optimizer_gen = Adam(generator.parameters(), lr=0.0001)
optimizer_disc = Adam(discriminator.parameters(), lr=0.0001)
iterator = iter(data_loader)
# sposto i modelli sulla GPU. N.B. : se la memoria viene ecceduta capire quale dei modelli non mettere in memoria
generator.to(device_type)
discriminator.to(device_type)
vgg19_model.to(device_type)
vgg_loss = VggLoss(vgg19_model)
adversarial_loss = AdversarialLoss(discriminator)
discriminator_loss = DiscriminatorLoss(discriminator)
for iteration in range(num_of_iter):
# resetto i gradienti
optimizer_gen.zero_grad()
optimizer_disc.zero_grad()
# prendo il prossimo batch
batch, id = next(iterator)
low_res_imgs = batch['low_res'].to(device_type)
high_res_imgs = batch['high_res'].to(device_type)
# genero le immagini ridimensionate
upscaled_images = generator(low_res_imgs)
# valuto la loss avversaria
loss_adv = adversarial_loss(upscaled_images[0].unsqueeze(0))
print("before adversarial loss ", iteration)
for i in range(1, len(upscaled_images)):
loss_adv += adversarial_loss(upscaled_images[i].unsqueeze(0))
print(loss_adv)
print("after adversarial loss ", iteration)
# valuto la loss vgg
loss_vgg = vgg_loss(upscaled_images[0].unsqueeze(0), high_res_imgs[0].unsqueeze(0))
print("before vgg loss ", iteration)
for i in range(1, len(high_res_imgs)):
loss_vgg += vgg_loss(upscaled_images[i].unsqueeze(0), high_res_imgs[i].unsqueeze(0))
print("after vgg loss ", iteration)
# calcolo la loss finale
losses_sum = torch.div((0.006 * loss_vgg + 0.001 * loss_adv), batch_size)
# calcolo il gradiente
losses_sum.backward()
#print(high_res_imgs[0].grad)
# aggiorno i parametri del generatore
optimizer_gen.step()
print("print before loss disc ", iteration)
loss_disc = discriminator_loss(upscaled_images[0].unsqueeze(0), high_res_imgs[0].unsqueeze(0))
print("a")
for i in range(1, len(upscaled_images)):
loss_disc += discriminator_loss(upscaled_images[i].unsqueeze(0), high_res_imgs[i].unsqueeze(0))
print(loss_disc)
print("b")
# meno in maniera tale che quando si discende il gradiente in realtà lo si massimizza
loss_disc = torch.div(-1 * loss_disc, batch_size)
print("c")
# devo ascendere il gradiente
loss_disc.backward()
print("d")
# aggiorno i parametri del discriminatore
optimizer_disc.step()
print("e")
print("print after loss disc ", iteration)
if iteration % 3 == 0:
print("at iteration ", iteration)
transform = T.ToPILImage()
upscaled_img = upscaled_images[0][0]
high_res_img = high_res_imgs[0][0]
img = transform(upscaled_img)
img.show()
img = transform(high_res_img)
img.show()
When calling the methods with the following code
disc = Discriminator()
num_of_iter = 10000
batch_size = 16
train_gan(num_of_iter, batch_size, dev, data_loader, generator, disc, vgg19_model)
I get the following error : Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
Not sure why it’s happening since I’m not calling backward on the same loss two times.