Training a Region Proposal Network with a ResNet-101 Backbone

Training Problems for an RPN

I am trying to train a network for region proposals, as in the anchor box concept
from Faster R-CNN.

I am using a pretrained ResNet-101 backbone with the last three layers popped off: the
conv5_x block, the average pooling layer, and the fully connected layer.

As a result, for input images of size 600×600, the convolutional feature map fed to the
RPN head has a spatial resolution of 37×37 with 1024 channels.
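A quick way to sanity-check the backbone's output resolution is to push a dummy tensor through the truncated network (a minimal sketch; whether you get 37 or 38 depends on the rounding in the stride-2 layers):

import torch
from torchvision import models

# ResNet-101 with conv5_x, avgpool, and fc dropped, as described above
backbone = torch.nn.Sequential(*list(models.resnet101(pretrained=True).children())[:-3])
backbone.eval()

with torch.no_grad():
    out = backbone(torch.zeros(1, 3, 600, 600))
print(out.shape)  # roughly torch.Size([1, 1024, 38, 38])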

I have set only the conv4_x block of the backbone to be trainable.
From there I am using the torchvision.models.detection rpn code:
the rpn.AnchorGenerator, rpn.RPNHead, and ultimately rpn.RegionProposalNetwork classes.
The call to forward returns two losses: the objectness loss and the box regression loss.

The issue I am having is that my model is training very, very slowly. In Girshick's original paper he trains over 80k mini-batches (roughly eight epochs, since the Pascal VOC 2012 dataset has about 11,000 images), where each mini-batch is a single image with 256 sampled anchor boxes. My network's loss improves VERY slowly from epoch to epoch, and I am training for 30+ epochs.

Below is my class code for the network.

import torch
from torchvision import models
from torchvision.models.detection import rpn
from torchvision.models.detection import image_list as il


class ResnetRegionProposalNetwork(torch.nn.Module):
    def __init__(self):
        super(ResnetRegionProposalNetwork, self).__init__()
        # ResNet-101 up to and including conv4_x / layer3 (drops conv5_x, avgpool, fc)
        self.resnet_backbone = torch.nn.Sequential(
            *list(models.resnet101(pretrained=True).children())[:-3])
        # freeze the first five children: conv1, bn1, relu, maxpool, layer1
        non_trainable_backbone_layers = 5
        counter = 0
        for child in self.resnet_backbone:
            if counter < non_trainable_backbone_layers:
                for param in child.parameters():
                    param.requires_grad = False
                counter += 1
            else:
                break

        # NOTE: torchvision pairs one size tuple per feature-map level; with a
        # single feature map, all five scales belong in a single tuple --
        # the original ((32,), (64,), ...) assumed five levels and only the
        # 32-pixel anchors would actually be generated here
        anchor_sizes = ((32, 64, 128, 256, 512),)
        aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
        self.rpn_anchor_generator = rpn.AnchorGenerator(
            anchor_sizes, aspect_ratios
        )
        out_channels = 1024  # channel count of the conv4_x feature map
        self.rpn_head = rpn.RPNHead(
            out_channels, self.rpn_anchor_generator.num_anchors_per_location()[0]
        )

        rpn_pre_nms_top_n = {"training": 2000, "testing": 1000}
        rpn_post_nms_top_n = {"training": 2000, "testing": 1000}
        rpn_nms_thresh = 0.7
        rpn_fg_iou_thresh = 0.7
        rpn_bg_iou_thresh = 0.2
        rpn_batch_size_per_image = 256
        rpn_positive_fraction = 0.5

        self.rpn = rpn.RegionProposalNetwork(
            self.rpn_anchor_generator, self.rpn_head,
            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
            rpn_batch_size_per_image, rpn_positive_fraction,
            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)

    def forward(self,
                images,       # type: Tensor, batched (N, 3, H, W)
                targets=None  # type: Optional[List[Dict[str, Tensor]]]
                ):
        feature_maps = self.resnet_backbone(images)
        # the torchvision RPN takes a dict of feature maps keyed by level
        features = {"0": feature_maps}
        image_sizes = [tuple(img.shape[-2:]) for img in images]
        image_list = il.ImageList(images, image_sizes)
        return self.rpn(image_list, features, targets)

I am using the Adam optimizer with the following parameters:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, ResnetRPN.parameters()), lr=0.01, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

My training loop is here (P, saveModelDuringTraining, and printToFile are helpers defined elsewhere in my code):

running_loss = 0.0
for epoch_num in range(epochs):  # train for `epochs` epochs per execution of this program
    loss_per_epoch = 0.0
    dl_iterator = iter(P.getPascalVOC2012DataLoader())
    current_epoch = epoch + epoch_num  # `epoch` is the offset loaded from the last checkpoint
    saveModelDuringTraining(current_epoch, ResnetRPN, optimizer, running_loss)
    batch_number = 0
    for image_batch, ground_truth_box_batch in dl_iterator:
        optimizer.zero_grad()
        boxes, losses = ResnetRPN(image_batch, ground_truth_box_batch)
        total_loss = losses["loss_objectness"] + losses["loss_rpn_box_reg"]
        total_loss.backward()
        optimizer.step()
        running_loss += float(total_loss)
        batch_number += 1
        if batch_number % 100 == 0:  # log the accumulated loss every 100 batches
            print('[%d, %5d] loss: %.3f' %
                  (current_epoch + 1, batch_number, running_loss))
            string_to_print = "\n epoch number:" + str(current_epoch + 1) + ", batch number:" \
                              + str(batch_number) + ", running loss: " + str(running_loss)
            printToFile(string_to_print)
            loss_per_epoch += running_loss
            running_loss = 0.0
    print("finished Epoch with epoch loss " + str(loss_per_epoch))
    printToFile("Finished Epoch: " + str(current_epoch + 1) + " with epoch loss: " + str(loss_per_epoch))

I am considering trying the following ideas to fix the very slow training:

  • trying various learning rates (although I have already tried 0.01, 0.001, and 0.003 with similar results)
  • various batch sizes (so far the best results have been batches of 4, i.e. 4 images × 256 anchors per image)
  • freezing more/less layers of the Resnet-101 backbone
  • using a different optimizer altogether
  • different weightings of the two losses (a quick sketch of what I mean is below)
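For the last idea, a minimal sketch of what I mean by re-weighting the two losses (lambda_reg is a made-up value that would need tuning):

# hypothetical weighting of the two RPN losses; tune lambda_reg on validation data
lambda_reg = 10.0
total_loss = losses["loss_objectness"] + lambda_reg * losses["loss_rpn_box_reg"]
total_loss.backward()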

Any hints, or anything obviously wrong with my approach, would be MUCH APPRECIATED. I would be happy to give more information to anyone who can help.


Hi! I too want to train an RPN with a pretrained backbone for my own project, so I was wondering: did you have any luck resolving your issue, or do you have any additional tips?

Great work by the way!

Hi Stefan. Yes, I actually have a lot of tips so far. I am learning as I go here, so some of these might be super obvious/useless for you. Hopefully nothing I'm saying is outright wrong; I'm on a learning journey, so be kind to me, internet (lol). That being said…

Here are some thoughts I have had:

  • First of all, note that OpenCV reads image files in BGR channel order by default, which I embarrassingly learned after the fact, having mistrained a few networks. Furthermore, TensorFlow's decoding of .jpeg and .png files (tf.image.decode_jpeg, tf.image.decode_png) returns arrays in (height, width) order, not (width, height) like I had thought. These pretrained models MUST take input in (3 × H × W) format with RGB channel order (see the preprocessing sketch after this list).

  • When using image resizing algorithms, use OpenCV to make 100% sure that the result looks correct by drawing the resized boxes on the resized image. I've settled on 600×600 images with bilinear interpolation; cubic interpolation in TensorFlow was giving me strange problems. Out of paranoia that the normalization was messing with the inputs, I even wrote a de-normalization function (i.e. inverting the procedure) to check that the result matched the original image. Checking the inputs just cannot be stressed enough, especially if you are serializing them, as I was with TensorFlow's TFRecords. A lot of sneaky problems can occur along the way, so right before feeding your batches to a model, I would check the inputs manually and view the images in OpenCV with the boxes drawn.

  • Second, you should definitely use the image normalization described on the PyTorch vision models page if you want to use the ImageNet pretrained weights (it's included in the preprocessing sketch after this list). I also wrote a script to find the mean R, G, B values and standard deviations of different datasets, like the SUN RGB-D dataset and Pascal VOC 2007/2012, in case you want to train from scratch (i.e. not use ImageNet weights but still normalize the channels).

  • For the optimizer, almost all of the papers I am reading on RPNs (and on machine-vision models generally) use some variant of stochastic gradient descent with momentum, together with a learning rate scheduler, where for the first portion of training the learning rate is either 0.001 or 0.003.
    Something like:
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, ResnetRPN.parameters()), lr=0.001, momentum=0.9, weight_decay=0.0005)
    After some number of iterations, they decrease the learning rate to 0.0001 for fine-tuning, and I think this is pretty important (see the scheduler sketch after this list). I'm not sure why the Adam optimizer with default settings did so badly; it could have been due to errors on my end, but I am just going to follow what I read in most of the papers I'm checking out (e.g. Faster R-CNN, Fast R-CNN, DP-FCN, R-FCN, etc.).

  • During training, it's worth examining which loss (regression or objectness) is really spiking the total. For me, on Pascal VOC 2007/2012 the objectness loss improved pretty rapidly, but the regression loss was really the main source of error (even after re-weighting as in Faster R-CNN, i.e. dividing by 256, or by the number of region proposals in a batch). A question worth asking: which anchors are really causing the problem? E.g., can a small 32×32 anchor that is not well centered on an object really learn to grow into a massive 500×500 box?

  • The number of iterations in the original Faster R-CNN paper seems a bit low to me; my own models benefited from more iterations (up to 32 passes over the same dataset, with 128 region proposals per image). This may not be good advice lol.

  • Visually testing the RPN after training is a good way to inspect what the issues might be (boxes all being the same size, missing very obvious edge/colour cues). I wrote a bunch of fake RPNs that do things like randomly sample box sizes and (x_min, y_min) locations, or randomly choose anchor boxes as in Faster R-CNN, or sample (x_min, y_min) from a normal distribution with certain heights and widths. I had a few other ones (making an even guess in each quadrant, etc.), but you get the idea.
    I compared my RPN to these fake ones, and it did much better, thankfully. Since the random RPNs don't use any image information, if your network can't vastly outperform them, you know something is seriously wrong. I was surprised how good the network I trained on Pascal VOC 2007 was even with pretty bad training loss (finding really small objects, PERFECT bounding boxes, overall good average IoU for the higher-ranked boxes). Take the 10 most highly ranked (by objectness) boxes and compute their max IoU with the ground truth boxes; if you consistently get IoUs over 0.7 you are in good territory (see the IoU sketch after this list). I'm actually thinking of writing an article, "What Makes a Good Region Proposal Network". The literature is pretty sparse on this topic; there aren't really any agreed-upon testing metrics for this task that I can see.

  • As for transfer learning, which layers to freeze, and all that jazz… so far I have just not been able to make transfer learning work that well, but I have recently fixed a lot of problems, so I will get back to you on this one. My best results to date come from training from scratch on MS COCO 2014/2015 with all layers training at the same time, which is very slow but worth it. All my layers were initialized uniformly at random.

  • There is also the option of popping more layers off the ResNet-101 backbone in order to get a larger feature map. I find the stride is quite large; if you want to do things like position-sensitive pooling or compute masks afterwards, you definitely benefit from a larger feature map in some circumstances.

  • I also wrote a "dilated" ResNet-101 backbone that, for various layer configurations, does not pop off layers. This way you can keep the layer configuration and choose based on the size of the feature maps you want: keep the stride of the 1×1 conv2d in the bottleneck downsample branch (before batch normalization) at 1, and set the dilation of the convolutions to 2 (although I'd like to experiment with 3, 5, 7). This may be trickier when using pretrained weights, however, since those kernels were not originally trained with any dilation. I have yet to test these dilated networks, but I will (see the dilation sketch after this list).
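Since a few of the tips above point at sketches, here they are. First, the preprocessing sketch from the OpenCV and normalization tips: BGR to RGB, (H, W, 3) to (3, H, W), and the ImageNet mean/std from the torchvision models page. A minimal sketch; load_image_for_backbone is just an illustrative name:

import cv2
import torch

# ImageNet statistics from the torchvision pretrained-models documentation
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def load_image_for_backbone(path, size=(600, 600)):
    image = cv2.imread(path)                        # (H, W, 3), BGR, uint8
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # (H, W, 3), RGB
    image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
    tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0  # (3, H, W) in [0, 1]
    return (tensor - IMAGENET_MEAN) / IMAGENET_STD

Remember that whatever resize you do here has to be applied to the ground truth boxes too, as in the box-drawing check above.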
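Next, the scheduler sketch from the optimizer tip, using torch.optim.lr_scheduler.MultiStepLR; the 60k-iteration milestone is illustrative and should match your own training budget:

optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, ResnetRPN.parameters()),
    lr=0.001, momentum=0.9, weight_decay=0.0005)
# decay the learning rate 0.001 -> 0.0001 once, then fine-tune;
# call scheduler.step() once after every optimizer.step()
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[60000], gamma=0.1)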
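The top-10 IoU check from the testing tip is only a few lines with torchvision.ops.box_iou (a sketch; boxes are (x_min, y_min, x_max, y_max) tensors, and top_k_max_ious is my illustrative name):

import torch
from torchvision.ops import box_iou

def top_k_max_ious(proposals, objectness, gt_boxes, k=10):
    # rank the proposals by objectness score and keep the top k
    top_k = proposals[objectness.argsort(descending=True)[:k]]
    # best IoU of each kept proposal against any ground truth box
    return box_iou(top_k, gt_boxes).max(dim=1).values

If these values are consistently above 0.7, the RPN is comfortably beating the random baselines.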
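Finally, for the dilation tip: if your torchvision version is recent enough, the resnet builders accept replace_stride_with_dilation, which swaps a stage's stride-2 downsampling for dilation without hand-edited bottleneck blocks. A sketch (note the dilated conv5_x output has 2048 channels, so out_channels in the RPN head would change accordingly):

from torchvision import models

# keep conv5_x, but replace its stride-2 downsampling with dilation 2,
# so the overall output stride stays 16 and the feature map stays large;
# you would still drop avgpool/fc as above to use it as a feature extractor
backbone = models.resnet101(pretrained=True,
                            replace_stride_with_dilation=[False, False, True])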

Hopefully this helps.


Thanks for the tips! Really appreciate it. By the way, did you try using SGD with a cyclical learning rate? Maybe that could help boost your performance.

Can you give me the full training code, sir?
Thanks a lot!

My email: long.ln181599@gmail.com

For object detection? Just download the MS COCO dataset or Pascal VOC to play around with.


Many thanks for your effort!
Could you please send me the full code for training the RPN?
Thank you so much! zcoguz@gmail.com

I would have to dig through some pretty old code at this point. I recommend just studying the torchvision implementation of RPN here:

and pairing it with whatever backbone you want. The input to the RPN is a dict of feature maps indexed by level, where the keys are strings, i.e. one entry per "level" of your backbone's output. The levels typically have the same number of channels but differing spatial resolutions, and are used for feature-pyramid-network-style implementations. During training, the output at index 1 (it returns a tuple) is the dict of regression and objectness losses. Simply write a training loop with a dataloader for your desired dataset and train the RPN as you would any regression-based deep learning model. It's up to you how to weight the objectness and regression losses.
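To make that concrete, a minimal sketch; backbone, rpn_network, and targets stand in for whatever backbone, RegionProposalNetwork instance, and ground truth you build, and the shapes are illustrative:

import torch
from torchvision.models.detection.image_list import ImageList

images = torch.rand(2, 3, 600, 600)   # a batch of two images
features = {"0": backbone(images)}    # one dict entry per backbone "level"
image_list = ImageList(images, [(600, 600), (600, 600)])

# in training mode the RPN returns (proposals, loss_dict)
proposals, losses = rpn_network(image_list, features, targets)
total_loss = losses["loss_objectness"] + losses["loss_rpn_box_reg"]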
