Hello.
I am trying to compute a custom loss function, but I can’t seem to be able to train the network, even if I call loss.backward(retain_graph=True)
Here is the training code:
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, image in enumerate(train_loader):
optimizer.zero_grad()
data = image[0].to(device)
output = model(data).to(device)
#print(VGG_model(output).shape)
data_loss = VGG_model(data)
output_loss = VGG_model(output)
content_loss = F.mse_loss(data_loss[-1], output_loss[-1])
style_loss = 0
style_loss += 0.25 * (F.mse_loss(torch.mean(img_stil_loss[0]), torch.mean(output_loss[0])) + F.mse_loss(torch.std(img_stil_loss[0]), torch.std(output_loss[0])))
style_loss += 0.25 * (F.mse_loss(torch.mean(img_stil_loss[1]), torch.mean(output_loss[1])) + F.mse_loss(torch.std(img_stil_loss[1]), torch.std(output_loss[1])))
style_loss += 0.25 * (F.mse_loss(torch.mean(img_stil_loss[2]), torch.mean(output_loss[2])) + F.mse_loss(torch.std(img_stil_loss[2]), torch.std(output_loss[2])))
style_loss += 0.25 * (F.mse_loss(torch.mean(img_stil_loss[3]), torch.mean(output_loss[3])) + F.mse_loss(torch.std(img_stil_loss[3]), torch.std(output_loss[3])))
loss = content_loss + 0.5 * style_loss
loss.backward(retain_graph=True)
optimizer.step()