I am working on a GAN for super-resolution at the moment (SRGAN). This GAN's generator is initialized from a previously trained CNN. In summary, one should train the generator first, and then train the GAN starting from a decent generator initialization. I have trained the CNN and its results are as expected. Now, when training the GAN, I am not getting the expected results, and the generated images show artifacts like these:
Super-resolution image (CNN)
Super-resolution image (GAN)
The training setup is, in summary:
- Train the generator (optimized for MSELoss);
- Train the GAN (optimized for Perceptual Loss, as described in the paper), initializing the generator with the results from the first step. The optimizer here has the same settings as in step 1 (Adam, LR=0.001).
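A minimal sketch of this two-stage setup with toy modules (a single conv layer stands in for the real SRResNet, so the snippet runs on its own), including the fresh optimizer started in stage 2:

```python
import io
import torch

# Stand-in for SRResNet; a single conv so the sketch is self-contained.
generator = torch.nn.Conv2d(3, 3, 3, padding=1)

# Stage 1: generator pre-training, optimized for pixel-wise MSE.
opt_stage1 = torch.optim.Adam(generator.parameters(), lr=1e-3)
lr_patch = torch.rand(2, 3, 24, 24)  # toy "LR" batch
hr_patch = torch.rand(2, 3, 24, 24)  # toy "HR" batch (no upscaling here)
loss = torch.nn.functional.mse_loss(generator(lr_patch), hr_patch)
opt_stage1.zero_grad()
loss.backward()
opt_stage1.step()

# Stage 2: reuse the trained weights, but build a *fresh* Adam optimizer,
# since the loss function changes from MSE to the perceptual loss.
buffer = io.BytesIO()
torch.save(generator.state_dict(), buffer)
buffer.seek(0)
generator.load_state_dict(torch.load(buffer))
opt_stage2 = torch.optim.Adam(generator.parameters(), lr=1e-3)
```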
Among other things, I am thinking it might be a problem with the generator's optimizer. After training the CNN separately, I didn't save the optimizer's state_dict, and I started training the GAN with a new optimizer using the same parameters. I thought this was right because the generator should be optimized with a different loss function, but I'm not sure. Does anyone have input on why this problem is happening? Do you think it could be related to the optimizer, or to something else?
During GAN training? Are you using only the adversarial loss? Ideally you should train with MSE + adversarial + VGG together. And what kind of discriminator are you using?
I am using VGG + Adversarial for the GAN. I don't use MSE for the GAN because the paper states a choice between the MSE and VGG losses, which I followed. Also, the discriminator is exactly the same as described in the paper.
I’m trying to be as close as possible to what the paper says.
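Concretely, the generator objective I am using looks roughly like this (a sketch: the feature extractor is a random stand-in so the snippet runs on its own, and the 1e-3 adversarial weighting is my reading of the paper):

```python
import torch
import torch.nn.functional as F

# Random stand-in for a frozen, pretrained VGG feature extractor.
features = torch.nn.Conv2d(3, 8, 3, padding=1)

def generator_loss(sr, hr, d_fake, adv_weight=1e-3):
    # Content term: MSE in VGG feature space.
    content = F.mse_loss(features(sr), features(hr))
    # Adversarial term: -log D(G(LR)); d_fake is the discriminator's
    # probability output on the generated batch.
    adversarial = -torch.log(d_fake + 1e-8).mean()
    return content + adv_weight * adversarial
```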
These artifacts can appear if you are using only the adversarial loss to update both the generator and the discriminator. Can you use the adversarial, MSE, and perceptual losses together in a single weight update? Check the pix2pix and DCGAN update rules and follow them.
I was using Content Loss (features) and Adversarial Loss to train the GAN. In the paper, the author calculates the Content Loss either by comparing the generated image (SR image) with the original image (HR image) using MSE loss, or by comparing the feature maps produced when the generated SR image and the HR image are fed through the VGG network, also using MSE loss.
On the other hand, it is not clear whether the Content Loss is exclusive. I mean, one may ask if the Content Loss for the GAN is calculated only using the VGG feature maps (Case 1), or if it uses the VGG feature maps AND MSELoss(SRimage, HRimage) (Case 2).
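In code, the two cases I am comparing look roughly like this (the VGG extractor is a random stand-in here so the snippet runs on its own):

```python
import torch
import torch.nn.functional as F

# Random stand-in for a frozen, pretrained VGG feature extractor.
vgg_features = torch.nn.Conv2d(3, 8, 3, padding=1)
for p in vgg_features.parameters():
    p.requires_grad_(False)

def content_loss_case1(sr, hr):
    # Case 1: MSE in VGG feature space only.
    return F.mse_loss(vgg_features(sr), vgg_features(hr))

def content_loss_case2(sr, hr, pixel_weight=1.0):
    # Case 2: VGG feature-space MSE plus pixel-wise MSE.
    return content_loss_case1(sr, hr) + pixel_weight * F.mse_loss(sr, hr)
```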
I have briefly tested both options for the Content Loss. Training with Case 1 as described generates the bad images I originally mentioned (those with the artifacts and low PSNR). Case 2, on the other hand, looks promising, as it holds a PSNR level similar to the one reported in the paper; but I didn't finish training it, because it seems to me that this is not what the author originally did, given the following statement in the paper:
Instead of relying on pixel-wise losses we build on the ideas of Gatys et al., Bruna et al. and Johnson et al. and use a loss function that is closer to perceptual similarity. We define the VGG loss based on the ReLU activation layers of the pre-trained 19 layer VGG network described in Simonyan and Zisserman.
This statement makes me think that Case 1 is the right one, even though I’m not reaching the expected results just yet.
Also, maybe I'm being too strict, and the "instead" just means that it will not rely only on pixel-wise losses, but on pixel-wise and VGG losses together, which would make Case 2 the right one. Let me know what you think; any input is welcome. This is a very important work for me.
The behaviour you describe for Case 1 and Case 2 is expected. Please send the paper title; I will have a look.
I posted it originally. Quoting it:
I am working on a GAN for super-resolution at the moment (SRGAN).
Thanks for helping out.
Went through the paper. Yeah, they seem to propose the VGG loss as an alternative to the MSE loss. Since MSE is directly related to PSNR, training with MSE would be ideal for that metric; but here, since the concern is more about perceptual quality, they use the VGG perceptual loss. That is also the reason they report mean opinion score measures. Now, to make it work: it is better to use the smaller patch size (96x96) mentioned in the paper, and to train for a higher number of epochs than you would with MSE, because there is no direct correlation with pixel values. I have also seen similar works that try to bring in an indirect relation to the pixel-level mapping, but I had a hard time making them work.
Yes, that's pretty much what I'm doing. Recently I changed the patch size from 96 to 256. I did this because of PyTorch's pretrained VGG19, as it asks for an input size of at least 224, but it did not improve my results much, so I guess the problem is in my code. I'll review the code as a whole tomorrow, in the hope of finding and fixing a mistake.
The good part is that I can reach the reported metrics of SRResNet-MSE, which makes me suspect that the problem is related to my VGG Loss implementation.
Cool. Which framework are you using? PyTorch?
Share the code. I will see if I can help.
Sorry for the delay, I was basically rewriting the code. It seems like a decent version to share for now. I will write a README tomorrow, which will help you train and test the model, and I will also upload a pretrained SRResNet tomorrow. You can control the training parameters by editing "config/train.json", and you will need to manually change the output folder for testing in test.py. I'm training on approximately 400k images from ImageNet. You will need to download the eval datasets too, but I'll upload them tomorrow, when I have some free time. The eval datasets I'm using at the moment are: Set5, Set14, BSD100, Urban100 and Manga109. The SRResNet-MSE results seem fine, and SRResNet-VGG22 looks better, but there's still room to improve. I haven't tested SRGAN yet. If you have any feedback, mainly regarding SRResNet-VGG22, please share it with me.
Also, the PyTorch pretrained VGG models require an input of at least 3x224x224. I hadn't thought about this until a few minutes ago, but I need to forward 96x96 images through this VGG model. What would you do? I'm thinking about rescaling the image to 224, forwarding it through the VGG, and then downscaling the feature map, but that seems odd to me, since I would be losing information. The best-case scenario would be training a VGG from scratch on the 3x96x96 inputs I need, I think.