Unknown problem with bilinear CNN (not an error, but a problem in my implementation of the algorithm)

Bilinear CNN is a powerful network for texture classification: http://vis-www.cs.umass.edu/bcnn/. The original code is written in MatConvNet.
When I try to implement it in PyTorch, the accuracy rises to 35% (at epoch 10) and then decreases to 20%. The key operations are the outer product, average pooling, signed square root, and L2 normalization. The code is as follows:

# The definition of Bilinear CNN
# input: [batch, channel, height, width]
import torch
import torch.nn as nn
import torch.nn.functional as F

class VggBasedNet_bilinear(nn.Module):
    def __init__(self, originalModel):
        super(VggBasedNet_bilinear, self).__init__()
        # feature extraction up to Conv5_3 with ReLU (drop the final max pool)
        self.features = nn.Sequential(*list(originalModel.features)[:-1])

        self.classifier = nn.Linear(512 * 512, args.numClasses)

    def forward(self, x):
        # feature extraction from conv5_3 with ReLU
        # (784 = 28*28 spatial positions, assuming 448x448 inputs)
        x = self.features(x).view(-1, 512, 784)
        
        # outer product of the features at each spatial position, averaged over height*width
        x = torch.matmul(x, x.permute(0,2,1)).view(-1,512*512)/784.0

        # signed sqrt
        x = torch.mul(torch.sign(x),torch.sqrt(torch.abs(x)+1e-12)) 

        # L2 normalization
        x = F.normalize(x, p=2, dim=1)

        # final FC layer
        x = self.classifier(x)

        return x

I am sure there is nothing wrong with the rest of the code, because I only changed the network structure in a VGG16 fine-tuning script.
You may not be familiar with bilinear CNNs and how they work; that doesn't matter here. Is there any problem with the code above? Does each operation do what its comment says?

@Liang

  1. A quick debug step would be to first freeze all the network parameters (set the model to eval and use volatile Variables), extract the ReLU_5 features, and perform the bilinear pooling. Even without training, the outer product and pooling of the relu_5 features give a performance boost over the regular CNN features. If this works as expected, you have a sanity check that at least your forward pass is working (a sketch of this step follows the list below).

  2. Fix the batch-normalization parameters of the CNN.

  3. Train the classifier layers first, before fine-tuning the whole network.

  4. The optimizer should be SGD following the original paper.
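
A minimal sketch of that debug step, assuming 448x448 inputs, the plain torchvision VGG16, and a recent PyTorch version (torch.no_grad() instead of volatile Variables); the helper name is just a placeholder:

import torch
import torch.nn as nn
import torchvision

# Frozen VGG16 up to relu5_3 (drop the final max pool), as in the model above.
vgg = torchvision.models.vgg16(pretrained=True)
features = nn.Sequential(*list(vgg.features)[:-1]).eval()
for p in features.parameters():
    p.requires_grad = False

def bilinear_features(x):
    # x: [batch, 3, 448, 448]; conv5_3 with ReLU gives 28x28 = 784 positions
    with torch.no_grad():
        f = features(x).view(x.size(0), 512, -1)                # [B, 512, 784]
        b = torch.matmul(f, f.transpose(1, 2)) / f.size(2)      # outer product + avg pooling
        b = b.view(x.size(0), -1)                                # [B, 512*512]
        b = torch.sign(b) * torch.sqrt(torch.abs(b) + 1e-12)     # signed sqrt
        return nn.functional.normalize(b, p=2, dim=1)            # L2 normalization

# Train only a linear classifier (or an SVM) on these frozen features and check
# that it already beats regular pooled conv5_3 features.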

Thanks for your help.
1. I froze VGG's parameters and trained only the last layer. It reaches about 35%, not the expected 60%-70%.
2. I use VGG16 without batch normalization, which has the same structure as the original MatConvNet model.
3. I follow the official two-step training: first extract the deep features, then train the last layer on them. It helps and boosts the accuracy. (I don't know why, though; in my mind, freezing the VGG layers should be equivalent to extracting the features first and training the last layer afterwards.) The accuracy of the first stage is 59% (80.1% reported in the paper), and the accuracy of the second stage is 78% (84.1% reported). A sketch of the schedule I use is after this list.
4. I use the SGD optimizer with momentum (same as the paper).
5. Interestingly, the TensorFlow version works well (69% for the 1st stage and 84% for the 2nd stage). Could the difference come from the pretrained VGG16 models in the different DL frameworks? I love PyTorch, but it does not work well this time…
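
For reference, a minimal sketch of the two-stage schedule I am describing in point 3 (the learning rates, weight decay, and epoch counts are placeholders, not the paper's exact values):

import torch.optim as optim

model = VggBasedNet_bilinear(original_vgg16)

# Stage 1: freeze the VGG features and train only the new classifier.
for p in model.features.parameters():
    p.requires_grad = False
optimizer = optim.SGD(model.classifier.parameters(), lr=1.0,
                      momentum=0.9, weight_decay=1e-5)
# ... train for some epochs ...

# Stage 2: unfreeze everything and fine-tune the whole network with a small LR.
for p in model.features.parameters():
    p.requires_grad = True
optimizer = optim.SGD(model.parameters(), lr=1e-2,
                      momentum=0.9, weight_decay=1e-5)
# ... continue training ...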

Some side notes:

  • Did you preprocess your input images? It seems the TensorFlow implementation does a zero-centering.
  • Did you initialize your new Linear layers? Some Xavier init and a small constant are used for the weight and bias, respectively. (It seems the default is Xavier uniform when no args are passed.) A short sketch of both points follows this list.
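
A minimal sketch of both steps; the mean/std values are the standard torchvision ImageNet ones and the class count is a placeholder, so treat the exact numbers as assumptions that must match the pretrained model actually used:

import torch.nn as nn
from torchvision import transforms

# Input preprocessing: resize to 448x448 and zero-center / normalize.
preprocess = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# New classifier init: Xavier for the weight, small constant for the bias.
classifier = nn.Linear(512 * 512, 200)   # 200 is a placeholder class count
nn.init.xavier_uniform_(classifier.weight)
nn.init.constant_(classifier.bias, 0.1)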

1. Yes, I have preprocessed with the same mean and std.
2. Yes, I follow the TensorFlow version and use the same initializer.
By the way, I am converting the pretrained VGG16 model used by the TensorFlow version to PyTorch. I want to see the influence of the pretrained VGG.

That's a good idea. Let us know how it works out!

Using the TensorFlow-version VGG16 in PyTorch helps a little, with a gain of about 5% (attention: you have to change the variable names, the dimension order of the tensors, and the data normalization). I still haven't figured out the remaining gap, and I am going to switch to TensorFlow for better model optimization.
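
In case it helps anyone trying the same conversion, this is roughly what the dimension reordering looks like; the helper names here are mine, and the variable names in the TensorFlow checkpoint differ per implementation:

import numpy as np
import torch

# TensorFlow stores conv kernels as [H, W, in_channels, out_channels];
# PyTorch expects [out_channels, in_channels, H, W].
def tf_conv_to_torch(kernel: np.ndarray) -> torch.Tensor:
    return torch.from_numpy(kernel).permute(3, 2, 0, 1).contiguous()

# Fully connected weights: TF uses [in_features, out_features], PyTorch uses
# [out_features, in_features].
def tf_fc_to_torch(weight: np.ndarray) -> torch.Tensor:
    return torch.from_numpy(weight).t().contiguous()

# The state_dict keys (e.g. 'features.0.weight') still have to be matched by
# hand to the names in the TF checkpoint, and the input normalization has to
# match the one the TF model was trained with.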

    # L2 normalization
    x = F.normalize(x, p=2, dim=1)

I think F.normalize cannot implement the original L2 norm exactly as in the paper; the original L2 norm should use dim=None (a norm over the whole tensor), which is not supported by F.normalize.
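
If the intent really is to normalize over all elements of the tensor at once, one way to sketch it manually (this is an interpretation of the comment above, not the official implementation):

import torch

def l2_normalize_all(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Divide by the L2 norm computed over every element of the tensor,
    # rather than per sample as F.normalize(x, p=2, dim=1) does.
    return x / (x.norm(p=2) + eps)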