Swin-Based Backbone for FasterRCNN

I trained a model using resnet50_fpn_v2 on my dataset, but I believe performance can be improved by using a SwinV2 backbone with an FPN. I implemented this based on my understanding and on the resnet50_fpn_v2 source code (comments added for clarity):

from collections import OrderedDict
from typing import Callable, Dict, List, Optional

import torch
from torch import Tensor, nn
from torchvision.models import Swin_V2_S_Weights, swin_v2_s
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.ops import FeaturePyramidNetwork, MultiScaleRoIAlign
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, LastLevelMaxPool


class IntermediateLayerGetter(nn.ModuleDict):
    # This is to get intermediate layer features (modified from the PyTorch source code)
    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break
        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                # Swin gives channels-last output (B, H, W, C); permute to the
                # (B, C, H, W) layout that the FPN and FasterRCNN expect
                out[out_name] = torch.permute(x, (0, 3, 1, 2))
        return out
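

# Quick shape check for the layer getter on its own (a minimal sketch,
# assuming a 224x224 input). swin_v2_s runs its stages at strides 4/8/16/32,
# so after the permute I expect channels-first maps of shape
# (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7):
_feats = nn.Sequential(OrderedDict(
    (str(i), m) for i, m in enumerate(swin_v2_s(weights=None).features)))
_getter = IntermediateLayerGetter(_feats, {'1': '0', '3': '1', '5': '2', '7': '3'})
for _name, _feat in _getter(torch.rand(1, 3, 224, 224)).items():
    print(_name, _feat.shape)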


class BackboneWithFPN(nn.Module):
    # This class is for implementing FPN backbone (also modified from the PyTorch source code)
    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        x = self.body(x)
        x = self.fpn(x)
        return x


class CustomSwin(nn.Module):
    def __init__(self, backbone_model):
        super().__init__()
        # Take the outputs of the four Swin stages: features 1, 3, 5, 7,
        # i.e. the layers just before each patch-merging (downsampling) step
        return_layers = {
            '1': '0',
            '3': '1',
            '5': '2',
            '7': '3'
        }
        # Define the in_channels for each layer (for SwinV2 small)
        in_channels_list = [96, 192, 384, 768]
        # Create a new Sequential module with the features
        backbone_module = nn.Sequential(OrderedDict([
            (f'{i}', layer) for i, layer in enumerate(backbone_model.features)
        ]))
        # Create the BackboneWithFPN
        self.backbone = BackboneWithFPN(
            backbone_module,
            return_layers,
            in_channels_list,
            out_channels=256,
            extra_blocks=None
        )
        self.out_channels = 256

    def forward(self, x):
        return self.backbone(x)
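

# Smoke test for the wrapped backbone (a sketch with random weights): the
# FPN should emit maps '0'-'3' plus the extra 'pool' level, all 256 channels:
_fpn_bb = CustomSwin(swin_v2_s(weights=None))
for _name, _feat in _fpn_bb(torch.rand(1, 3, 224, 224)).items():
    print(_name, _feat.shape)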


def load_backbone(trainable_layers=6):
    # Load the stock swin_v2_s from torchvision with pretrained weights
    backbone = swin_v2_s(weights=Swin_V2_S_Weights.DEFAULT)
    # Remove the classification head (norm, permute, avgpool, flatten, and head)
    backbone.norm = nn.Identity()
    backbone.permute = nn.Identity()
    backbone.avgpool = nn.Identity()
    backbone.flatten = nn.Identity()
    backbone.head = nn.Identity()
    # Freeze all parameters
    for param in backbone.parameters():
        param.requires_grad = False
    # Unfreeze the last trainable_layers
    for layer in list(backbone.features)[-trainable_layers:]:
        for param in layer.parameters():
            param.requires_grad = True
    return backbone
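

# Freezing sanity check (a sketch; note load_backbone() downloads pretrained
# weights). With trainable_layers=6, features[2:] stay trainable (the last
# three stages plus the patch-merging layers between them), while the patch
# embedding (features[0]) and the first stage (features[1]) stay frozen:
_bb = load_backbone(trainable_layers=6)
for _i, _layer in enumerate(_bb.features):
    print(_i, any(p.requires_grad for p in _layer.parameters()))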


# Load the backbone and wrap it with the FPN (FasterRCNN reads
# backbone.out_channels, so it needs the wrapped version, not the raw model)
backbone = CustomSwin(load_backbone())

# Define anchor generator
anchor_generator = AnchorGenerator(
    sizes=((32,), (64,), (128,), (256,), (512,)),  # 5th for the pool layer
    aspect_ratios=((0.5, 1.0, 2.0),) * 5  # Same aspect ratio for all feature maps
)
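
# Sanity check (sketch): the RPN sees five feature maps ('0'-'3' plus 'pool'),
# so the generator needs five size tuples, each with 3 aspect ratios:
print(anchor_generator.num_anchors_per_location())  # expect [3, 3, 3, 3, 3]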

# Define RoI Pooler
roi_pooler = MultiScaleRoIAlign(
    featmap_names=['0', '1', '2', '3'],  # Ignore pool
    output_size=(7, 7),
    sampling_ratio=2
)

# Define the Faster R-CNN model
model = FasterRCNN(
    backbone,
    num_classes=len(CLASSES),  # note: num_classes must include the background class
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler,
    min_size=width,   # note: min_size/max_size are the shorter/longer image side
    max_size=height,  # for the internal resize transform, not width and height
).to(DEVICE)

# Replace the box predictor (FasterRCNN above already built one sized for
# len(CLASSES), so this re-initializes an identically shaped head)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(CLASSES)).to(DEVICE)
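
# End-to-end smoke test (a sketch reusing my width/height/DEVICE): in eval
# mode FasterRCNN takes a list of (C, H, W) tensors and returns one dict of
# 'boxes', 'labels', and 'scores' per image:
model.eval()
with torch.no_grad():
    _preds = model([torch.rand(3, height, width).to(DEVICE)])
print({_k: _v.shape for _k, _v in _preds[0].items()})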

Although I’m fairly confident in the implementation, when I start training the loss consistently gets stuck around 1.00, and during validation the model never makes a prediction with a confidence score higher than 0.5, even after 50 epochs. This suggests there’s an issue in my implementation, but I can’t pinpoint the problem.

If anyone could review my code and help identify what might be going wrong, I would greatly appreciate it!