As far as I can tell, I have done everything the paper describes. I'll skip the details that are probably fine, since my MSE-optimized network works as expected. The paper uses two VGG feature maps: VGG22 (the output of the 2nd convolution, after activation, before the 2nd max-pooling layer) and VGG54 (the output of the 4th convolution, after activation, before the 5th max-pooling layer). The VGG architecture used is VGG19.
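As a sanity check on which slice of `features` corresponds to VGG22 and VGG54, the indices can be derived from the VGG19 layer layout. This is only a sketch: `VGG19_CFG` is a hand-written copy of the standard VGG19 configuration, and `feature_slice_end` is a hypothetical helper, not part of torchvision. It assumes each convolution is followed by a ReLU (and by a batch-norm layer in between for the `bn` variant).

```python
# VGG19 layout: numbers are conv output channels, "M" is a max-pooling layer.
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]


def feature_slice_end(pool_idx, conv_idx, bn=False):
    """Return n such that features[:n] ends at the activation of the
    conv_idx-th convolution before the pool_idx-th max-pooling layer."""
    layers_per_conv = 3 if bn else 2  # conv, optional batch norm, ReLU
    position = 0        # number of layers counted so far
    pools_seen = 0      # max-pooling layers already passed
    convs_in_block = 0  # convolutions seen since the last max-pooling layer
    for entry in VGG19_CFG:
        if entry == "M":
            pools_seen += 1
            convs_in_block = 0
            position += 1
        else:
            convs_in_block += 1
            position += layers_per_conv
            if pools_seen == pool_idx - 1 and convs_in_block == conv_idx:
                return position
    raise ValueError("no such layer in VGG19")


print(feature_slice_end(2, 2))           # VGG22, no batch norm
print(feature_slice_end(5, 4))           # VGG54, no batch norm
print(feature_slice_end(2, 2, bn=True))  # VGG22, batch norm
print(feature_slice_end(5, 4, bn=True))  # VGG54, batch norm
```

Under these assumptions the helper gives 9 and 36 for the plain network and 13 and 52 for the batch-norm variant.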
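    import torch
    import torch.nn as nn
    from torchvision import models


    class VGG(nn.Module):
        def __init__(self, bn=False, loss_config='VGG54', pretrained=True):
            super(VGG, self).__init__()
            # Each slice ends at the ReLU just before the relevant max-pooling layer.
            if loss_config == 'VGG54':
                if not bn:
                    model = models.vgg19(pretrained=pretrained).features[:36]
                else:
                    model = models.vgg19_bn(pretrained=pretrained).features[:52]
            elif loss_config == 'VGG22':
                if not bn:
                    model = models.vgg19(pretrained=pretrained).features[:9]
                else:
                    model = models.vgg19_bn(pretrained=pretrained).features[:13]
            else:
                raise ValueError(f"Unknown loss_config: {loss_config!r}")
            # Freeze the pretrained weights so only the generator is updated.
            if pretrained:
                for param in model.parameters():
                    param.requires_grad = False
            self.model = model
            # ImageNet statistics, shaped to broadcast over (N, 3, H, W) inputs.
            mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)

        def forward(self, x):
            # Normalize (assumes x is in [0, 1]), extract features, rescale by 1/12.75.
            x = (x - self.mean) / self.std
            x = self.model(x).div(12.75)
            return x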
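A note on the `.div(12.75)` in `forward`: because MSE is quadratic, dividing both feature maps by 12.75 multiplies the resulting loss by 1/12.75², roughly 0.00615. The sketch below checks this in plain Python; the `mse` helper and the feature values are made up purely for illustration.

```python
def mse(a, b):
    """Mean squared error between two equally sized sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


SCALE = 12.75  # divisor applied to the feature maps in forward()

# Hypothetical feature values, chosen only to illustrate the scaling.
sr_feat = [0.0, 2.5, 7.0, 1.5]
hr_feat = [1.0, 2.0, 5.0, 3.0]

raw_loss = mse(sr_feat, hr_feat)
scaled_loss = mse([v / SCALE for v in sr_feat],
                  [v / SCALE for v in hr_feat])

# Scaling the inputs by 1/SCALE scales the loss by 1/SCALE**2.
loss_factor = 1.0 / SCALE ** 2
print(loss_factor, scaled_loss / raw_loss)
```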
This VGG class applies the ImageNet mean/std normalization that PyTorch recommends for its pretrained models, and I think it meets every other requirement as well. The training loop is defined as follows:
    SRimg = model_generator(LRimg)
    optim_generator.zero_grad()
    if loss_function in ('VGG22', 'VGG54'):
        # Fail safe: upsample to 224x224 before feeding VGG if either
        # spatial dimension is smaller. In practice this is never triggered.
        if min(HRimg.shape[2:]) < 224:
            SRimg_vgg = nnf.interpolate(SRimg, size=(224, 224), mode='bicubic', align_corners=False)
            HRimg_vgg = nnf.interpolate(HRimg, size=(224, 224), mode='bicubic', align_corners=False)
            SRfeat = model_feat(SRimg_vgg)
            HRfeat = model_feat(HRimg_vgg)
        else:
            SRfeat = model_feat(SRimg)
            HRfeat = model_feat(HRimg)
        loss_g = perceptual_loss(SRfeat, HRfeat)  # + tv_loss(SRimg)
    else:
        # Plain pixel-wise MSE when no VGG loss is selected.
        loss_g = perceptual_loss(SRimg, HRimg)
    loss_g.backward()
    optim_generator.step()
A few descriptions of variables:
model_generator = the generator itself, as described in the paper;
optim_generator = an Adam optimizer, as described in the paper;
nnf = torch.nn.functional;
model_feat = an instance of the VGG class defined above;
perceptual_loss = torch.nn.MSELoss().
The final else branch only applies when the VGG loss is not used; in that case the MSE is computed directly on the images.
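One detail about the 224-pixel fail safe in the training loop: comparing a size against `(224, 224)` with `<` is a lexicographic tuple comparison in Python, so it does not check each dimension independently. A minimal sketch of a per-dimension check follows; `needs_resize` and `min_size` are hypothetical names used only for this illustration.

```python
def needs_resize(height, width, min_size=224):
    """True if either spatial dimension is smaller than min_size."""
    return height < min_size or width < min_size


# Lexicographic comparison looks at the first element and stops there,
# so a 300x100 image is not flagged even though its width is too small.
print((300, 100) < (224, 224))  # False
print(needs_resize(300, 100))   # True
```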
If anything is unclear or if you need more information, let me know. I have done a bit of pruning to make the code easier to understand.