VGG Feature Maps Rescaled

So, I have been reading a paper saying that VGG feature maps were rescaled by 1/12.75. I have two questions regarding this:

  1. I assume this just means that all weights and biases have been rescaled, am I right?

  2. The original paper was done using Theano and Lasagne. I wonder if the rescale factor would be different when using the pretrained VGG in PyTorch.

Could you post a link to the paper or quote the relevant section?
I'm not sure whether rescaling refers to a multiplicative scale factor or to resizing the spatial dimensions.

Hey ptrblck, thanks for your time. Here is the paper: https://arxiv.org/abs/1609.04802
You can find the rescaling part in the description of the experiments.

Thanks for the reference!
It should be a multiplicative factor then.

So back to the questions.

  1. No, I think the activation maps (outputs of conv layers) are rescaled and the loss is calculated using equation 5.

  2. It shouldn’t be different, so I would recommend using it as described in the paper if you want to reproduce the results.

Thanks for your input, ptrblck. I originally started this thread while trying to find an issue in my reproduction of the paper, because I have been getting result images with checkerboard artifacts. As an example, I provide three images: the first is the original, the second is a super-resolution result based on the MSE loss, and the third is a super-resolution result based on the VGG loss. The problem appears only when the network is optimized with the perceptual loss based on VGG feature maps, as described in the paper. When optimized for the MSE loss, as you can see in the shown picture, the expected results and metrics are achieved. Would you have some insight into this issue? I can share some of the code if that helps, just let me know.

That’s interesting. Sure, feel free to post the code so that we can have a look at it.

So, I have done everything the paper tells me to do, at least as far as I can tell. I won't go too much into details that are probably fine, because my MSE-optimized network works just fine. The paper defines the VGG feature maps used as VGG22 (the output of the second convolution, after activation, before the second maxpooling layer) and VGG54 (the output of the fourth convolution, after activation, before the fifth maxpooling layer). The VGG architecture used is VGG19.

import torch
import torch.nn as nn
from torchvision import models


class VGG(nn.Module):
    def __init__(self, bn=False, loss_config='VGG54', pretrained=True):
        super(VGG, self).__init__()
        if loss_config == 'VGG54':
            if not bn:
                # up to relu5_4, i.e. right before the fifth maxpooling layer
                model = models.vgg19(pretrained=pretrained).features[:36]
            else:
                model = models.vgg19_bn(pretrained=pretrained).features[:52]

        elif loss_config == 'VGG22':
            if not bn:
                # up to relu2_2, i.e. right before the second maxpooling layer
                model = models.vgg19(pretrained=pretrained).features[:9]
            else:
                model = models.vgg19_bn(pretrained=pretrained).features[:13]

        if pretrained:
            for param in model.parameters():
                param.requires_grad = False

        self.model = model
        # ImageNet normalization stats expected by the pretrained VGG
        mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer('mean', mean)
        self.register_buffer('std', std)

    def forward(self, x):
        x = (x - self.mean) / self.std
        x = self.model(x).div(12.75)
        return x

This VGG has the required preprocessing, as PyTorch recommends. I think every other requirement is followed in this VGG class I created. The training loop is defined as follows:

SRimg = model_generator(LRimg)
optim_generator.zero_grad()
if loss_function == 'VGG22' or loss_function == 'VGG54':
    # The if/else block ensures a size of 224 prior to feeding the VGG.
    # The interpolation branch is never actually taken; it is just a fail safe.
    if HRimg.size()[2:] < (224, 224):
        SRimg_vgg = nnf.interpolate(SRimg, size=(224, 224), mode='bicubic', align_corners=False)
        HRimg_vgg = nnf.interpolate(HRimg, size=(224, 224), mode='bicubic', align_corners=False)
        SRfeat = model_feat(SRimg_vgg)
        HRfeat = model_feat(HRimg_vgg)
    else:
        SRfeat = model_feat(SRimg)
        HRfeat = model_feat(HRimg)
    loss_g = perceptual_loss(SRfeat, HRfeat)  # + tv_loss(SRimg)
else:
    loss_g = perceptual_loss(SRimg, HRimg)
loss_g.backward()
optim_generator.step()

A few descriptions of variables:
model_generator = the generator itself, as described in the paper;
optim_generator = an Adam optimizer, as described in the paper;
nnf = torch.nn.functional;
model_feat = the VGG class described above;
perceptual_loss = MSELoss()

The final else branch handles the case where the VGG loss is not being used.

If anything is unclear or if you need more information, let me know. I have done a bit of pruning to make the code easier to understand.
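
In case it helps, a quick shape sanity check of the feature extractor could look like the following (this is a hypothetical snippet, not part of my training code; the 96x96 input just matches the paper's crop size):

# Hypothetical sanity check: verify the output shape of the VGG54 extractor
# on a dummy 96x96 batch with values in [0, 1].
model_feat = VGG(loss_config='VGG54', pretrained=True)
model_feat.eval()

dummy = torch.rand(1, 3, 96, 96)
with torch.no_grad():
    feat = model_feat(dummy)
print(feat.shape)  # expected: torch.Size([1, 512, 6, 6]) for VGG54 on a 96x96 input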

In your approach you are only using (and rescaling) the final output activation of the base VGG model, while it seems that multiple activation maps are scaled and summed into the VGG loss in the paper:

We define the VGG loss based on the ReLU activation layers of the pre-trained 19 layer VGG network described in Simonyan and Zisserman [49]. With φ_{i,j} we indicate the feature map obtained by the j-th convolution (after activation) before the i-th maxpooling layer within the VGG19 network, which we consider given. We then define the VGG loss as the euclidean distance between the feature representations of a reconstructed image G_{θ_G}(I^{LR}) and the reference image I^{HR}.
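
For reference, the per-layer loss described above is given in equation 5 of the paper; reconstructed from this description, it reads roughly as

l^{SR}_{VGG/i.j} = \frac{1}{W_{i,j} H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}} \left( \phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}(G_{\theta_G}(I^{LR}))_{x,y} \right)^2

where W_{i,j} and H_{i,j} are the spatial dimensions of the feature map φ_{i,j}.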

You could get all intermediate activation maps using e.g. forward hooks or by manipulating the forward method of the VGG model.
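
As a rough sketch of the forward-hook approach (the layer indices below are my assumption for torchvision's plain vgg19 and should correspond to the last ReLU before each maxpooling layer, so please double check them):

import torch
from torchvision import models

# Minimal sketch: collect intermediate activations from vgg19 via forward hooks.
vgg = models.vgg19(pretrained=True).features.eval()
for param in vgg.parameters():
    param.requires_grad = False

activations = {}

def save_activation(name):
    def hook(module, inp, out):
        activations[name] = out
    return hook

# Assumed indices of relu1_2, relu2_2, relu3_4, relu4_4, relu5_4 in vgg19.features.
for idx in [3, 8, 17, 26, 35]:
    vgg[idx].register_forward_hook(save_activation('relu_{}'.format(idx)))

x = torch.rand(1, 3, 96, 96)
with torch.no_grad():
    _ = vgg(x)
for name, act in activations.items():
    print(name, act.shape)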


I appreciate the feedback, ptrblck, and I'm sorry for the delay. After some research, I decided to base my implementation on this gist: https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49

So I have modified mine to something similar to the following:

class VGG(nn.Module):
    def __init__(self, bn=False, loss_config='VGG54', pretrained=True):
        super(VGG, self).__init__()
        # Load the feature extractor once and slice it into blocks, each block
        # ending right after the last activation before a maxpooling layer.
        if not bn:
            features = models.vgg19(pretrained=pretrained).features
        else:
            features = models.vgg19_bn(pretrained=pretrained).features

        blocks = []
        if loss_config == 'VGG54':
            if not bn:
                blocks.append(features[:4].eval())     # up to relu1_2
                blocks.append(features[4:9].eval())    # up to relu2_2
                blocks.append(features[9:18].eval())   # up to relu3_4
                blocks.append(features[18:27].eval())  # up to relu4_4
                blocks.append(features[27:36].eval())  # up to relu5_4
            else:
                blocks.append(features[:6].eval())
                blocks.append(features[6:13].eval())
                blocks.append(features[13:26].eval())
                blocks.append(features[26:39].eval())
                blocks.append(features[39:52].eval())

        elif loss_config == 'VGG22':
            if not bn:
                blocks.append(features[:4].eval())     # up to relu1_2
                blocks.append(features[4:9].eval())    # up to relu2_2
            else:
                blocks.append(features[:6].eval())
                blocks.append(features[6:13].eval())

        blocks = nn.ModuleList(blocks)

        if pretrained:
            for param in blocks.parameters():
                param.requires_grad = False

        self.blocks = blocks

        mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer('mean', mean)
        self.register_buffer('std', std)

    def forward(self, x):
        output = []
        x = (x - self.mean) / self.std
        for block in self.blocks:
            x = block(x).div(12.75)
            output.append(x)
        return output

With this class I'm able to output a list of feature maps from the selected VGG architecture; each feature map is taken after the activation of a convolution block and before the following maxpooling layer, as the paper states. Unlike the gist I linked, I chose to generate the feature maps in this class and to calculate the loss separately, which works better in my code.
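
To make the "calculating the loss separately" part concrete, here is a minimal sketch (not my exact code) of how the loss can be assembled from the returned list; the equal weighting of the blocks is an assumption on my side, since the paper defines the loss on a single feature map:

import torch.nn as nn

# Sketch: perceptual loss accumulated over the list of feature maps
# returned by the VGG class above. Equal weights per block are an assumption.
mse = nn.MSELoss()

def perceptual_loss_from_features(sr_feats, hr_feats):
    loss = 0.0
    for sr_f, hr_f in zip(sr_feats, hr_feats):
        loss = loss + mse(sr_f, hr_f)
    return loss

# Usage in the training loop:
# SRfeat = model_feat(SRimg)
# HRfeat = model_feat(HRimg)
# loss_g = perceptual_loss_from_features(SRfeat, HRfeat)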

I think this is in line with what you said and with what other papers on perceptual losses describe, am I right? I'll briefly train the network and share the results as soon as I can.


I have only trained for a fraction of an epoch, just to get a glimpse of the difference the changes make. This is the image I was able to generate:

It seems like a huge improvement over the last one, which was fully trained. What you pointed out is probably the solution in my case. I'll share an image from the fully trained network later on, to confirm it is really working. Thanks a lot for your support.


One more thing. The paper says a crop size of 96 was used, while PyTorch's pretrained VGG models expect a crop size of at least 224. What would you do? Using 224 severely increases training time, but I'm unsure what it would mean to feed a smaller image to the VGG in this case. From what I can see, I have three options:

  1. Train with 96 crop size (faster);

  2. Train with 224 crop size (slower);

  3. Use interpolate to generate a 224 patch given a 96 patch (medium).

It might be worth trying out the smaller (and faster) inputs, if that’s suggested by the authors.
What would option 3 mean? Would you crop a 96x96 patch and then resize it to 224x224?

Yes, but on second thought it would be far from ideal.

I was finally able to solve my issue by building a more complex loss function.

After re-reading a comment from ptrblck, where he states that I should get the intermediate activation maps and add them to my loss function, I was able to generate plausible results. I'm not sure whether it was my misunderstanding or whether the related papers are not sufficiently clear on this point, but it seems very important not to optimize the model only for the feature maps of interest. Optimizing it for the earlier feature maps as well contributes greatly to the final result.

Finally, I would like to mention that I have read many tutorials and implementations of perceptual losses in the meantime. The curious thing is that the solution I now use, previously suggested by ptrblck, is only applied to style transfer, at least from what I have read. For super-resolution, the examples I have seen suggest that my first approach (computing the loss only on the feature maps of interest) is the correct one, but I highly doubt that. Thanks, ptrblck, for your support, and I hope my experience helps others in the future.


Thanks for getting back and congrats on the success in training your model!

In that case I would highly recommend thinking about writing a blog post, tutorial, notebook, etc.
Your approach might be useful for others who are running into the same issues (and are not seeing your posts here).


I will consider writing about it as soon as I have some more spare time. Thanks for the tip!
