ConvNeXt weights

Hi there,

I’m training a Faster R-CNN model with a ConvNeXt backbone. This is my model creation code:

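    import torchvision
    from torchvision.models.detection import FasterRCNN
    from torchvision.models.detection.rpn import AnchorGenerator

    # ConvNeXt-Large feature extractor; its last stage outputs 1536 channels
    backbone = torchvision.models.convnext_large(weights="DEFAULT").features
    backbone.out_channels = 1536  # FasterRCNN reads this attribute

    # A single feature map, so one tuple of sizes and one tuple of aspect ratios
    anchor_generator = AnchorGenerator(
        sizes=((8, 16, 24, 32),),
        aspect_ratios=((0.5, 0.25, 1/6),)
    )
    # The backbone returns a single tensor, which FasterRCNN names '0'
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['0'],
        output_size=7,
        sampling_ratio=2
    )
    model = FasterRCNN(
        backbone,
        num_classes=num_classes,  # defined elsewhere; includes the background class
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler
    )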
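
The `out_channels = 1536` line has to match what the backbone actually outputs. A minimal sketch for checking that, together with the stride of the single feature map the anchors are placed on, is to push one dummy image through the backbone. `probe_backbone` is a hypothetical helper (not a torchvision API) and only needs torch:

```python
import torch
from torch import nn


@torch.no_grad()
def probe_backbone(backbone: nn.Module, image_size=(3, 224, 224)) -> tuple:
    """Run one dummy image through `backbone`.

    Returns (channels, stride_h, stride_w) of the single output feature map.
    Assumes the backbone returns one 4-D tensor, as ConvNeXt `.features` does.
    """
    was_training = backbone.training
    backbone.eval()  # avoid training-only behaviour such as stochastic depth
    out = backbone(torch.zeros(1, *image_size))
    backbone.train(was_training)  # restore the original mode
    _, channels, h, w = out.shape
    return channels, image_size[1] // h, image_size[2] // w
```

For `convnext_large(...).features` this should report 1536 channels and a stride of 32 in both directions, i.e. the 8 to 32 pixel anchors are laid out on a grid with a 32-pixel step.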

I followed this tutorial:
TorchVision Object Detection Finetuning Tutorial (PyTorch Tutorials 2.2.0+cu121 documentation)

My problem is that the model seems to start training from randomly initialized weights rather than pretrained ones: mAP@50:95 starts from nearly zero. In contrast, with a pretrained ResNet-50 backbone, mAP@50:95 increases much faster. I also tried training the ResNet-50 backbone from random weights, and its mAP@50:95 curve looks similar to the one I get with the ConvNeXt backbone.
I should mention that I also tried passing pretrained=True to torchvision.models.convnext_large, and it made no difference.

Does anyone have an idea where the problem comes from? In theory, training with a ConvNeXt backbone should reach a higher mAP@50:95, but I get the opposite result.

David.

Hello David,
Could you try specifying the weights this way:

    import torchvision

    convnext_weights = torchvision.models.ConvNeXt_Large_Weights.IMAGENET1K_V1
    backbone = torchvision.models.convnext_large(weights=convnext_weights).features
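
However the weights are specified, one way to confirm that the backbone really holds the ImageNet checkpoint, rather than inferring it from training curves, is to compare its tensors with the checkpoint itself. A minimal sketch, assuming only torch; `compare_to_checkpoint` is a hypothetical helper, not a torchvision API:

```python
import torch
from torch import nn


def compare_to_checkpoint(module: nn.Module, reference: dict, prefix: str = "") -> dict:
    """Compare every tensor in module.state_dict() with reference[prefix + name].

    Returns the number of matching tensors, the names of tensors whose values
    (or shapes) differ, and the names that are missing from the reference.
    """
    matching, differing, missing = 0, [], []
    for name, tensor in module.state_dict().items():
        ref = reference.get(prefix + name)
        if ref is None:
            missing.append(name)
        elif ref.shape == tensor.shape and torch.allclose(
            tensor.detach().cpu().float(), ref.detach().cpu().float()
        ):
            matching += 1
        else:
            differing.append(name)
    return {"matching": matching, "differing": differing, "missing": missing}
```

For example, right after building the model and before any training step, `compare_to_checkpoint(model.backbone, torchvision.models.ConvNeXt_Large_Weights.IMAGENET1K_V1.get_state_dict(progress=True), prefix="features.")` should report every backbone tensor as matching if the pretrained weights were loaded (this assumes a torchvision release whose weight enums provide `get_state_dict()`). The `features.` prefix is needed because the checkpoint was saved from the full classification model, whose feature extractor lives under `features`.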

Visualizing the weights as images has sometimes helped me confirm that they were actually loaded:

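    import os
    import matplotlib.pyplot as plt
    import numpy as np

    working_directory = "."  # folder where the figure is saved

    # First conv of the ConvNeXt stem (model.backbone[0][0] inside Faster R-CNN),
    # shape (out_channels, 3, 4, 4)
    weights = backbone[0][0].weight.detach().cpu().numpy()
    SCALING = 125 / np.abs(weights).max()  # largest |weight| maps to +/-125 around grey
    fig, axes = plt.subplots(8, weights.shape[0] // 8)
    for cindex, ax in enumerate(axes.flatten()):
        # CHW -> HWC so each filter can be shown as a small RGB image
        image = np.transpose(weights[cindex, ...], (1, 2, 0)) * SCALING + 125
        image = np.clip(image, 0, 255)
        ax.imshow(image.astype(np.uint8))
        ax.set_axis_off()
    plt.savefig(os.path.join(working_directory, 'filter1_weight'))
    plt.close(fig)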
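
Instead of one SCALING factor shared by all filters, each filter can be min-max normalized on its own, which makes low-contrast filters easier to tell apart. This is a different normalization technique from the one above; `filters_to_images` is a hypothetical helper sketched with numpy only:

```python
import numpy as np


def filters_to_images(weights: np.ndarray) -> np.ndarray:
    """Map conv filters of shape (out, 3, kh, kw) to uint8 RGB images (out, kh, kw, 3).

    Each filter is min-max normalized independently; a constant filter
    (zero range) becomes mid-grey instead of causing a division by zero.
    """
    if weights.ndim != 4 or weights.shape[1] != 3:
        raise ValueError("expected weights of shape (out_channels, 3, kh, kw)")
    images = np.transpose(weights, (0, 2, 3, 1)).astype(np.float64)  # CHW -> HWC
    lo = images.min(axis=(1, 2, 3), keepdims=True)
    hi = images.max(axis=(1, 2, 3), keepdims=True)
    span = hi - lo
    safe_span = np.where(span > 0, span, 1.0)  # placeholder for zero-range filters
    scaled = np.where(span > 0, (images - lo) / safe_span, 0.5)  # values in [0, 1]
    return np.round(scaled * 255).astype(np.uint8)
```

With `images = filters_to_images(weights)` computed once before the plotting loop, `ax.imshow(images[cindex])` would replace the manual scaling and clipping.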