I trained a Faster R-CNN model (`fasterrcnn_resnet50_fpn_v2`) on my dataset, but I believe performance can be improved by using a SwinV2 backbone with an FPN. I implemented this based on my own understanding and on the `fasterrcnn_resnet50_fpn_v2` source code (comments are added for clarity):
class IntermediateLayerGetter(nn.ModuleDict):
    """Run the direct children of ``model`` in order and collect selected outputs.

    Adapted from ``torchvision.models._utils.IntermediateLayerGetter``. Only the
    children up to and including the last requested layer are kept; later
    children are dropped and never executed.

    Every collected output is permuted from ``(N, H, W, C)`` to ``(N, C, H, W)``.
    The wrapped Swin stages are expected to produce channels-last tensors, and
    the FPN / FasterRCNN consume channels-first ones.

    Args:
        model: module whose direct children are executed sequentially.
        return_layers: maps child names of ``model`` to the keys used in the
            returned dict.

    Raises:
        ValueError: if a key of ``return_layers`` is not a direct child of
            ``model``.
    """

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        # String-keyed working copy; entries are removed as layers are found,
        # so the caller's dict must not be mutated.
        remaining = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            remaining.pop(name, None)
            # Stop once every requested layer has been collected.
            if not remaining:
                break
        # Register the kept children as sub-modules. Without this call the
        # ModuleDict is never initialised and forward() has nothing to run.
        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return an ordered dict ``{out_name: NCHW feature}`` for the requested layers."""
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                # Only the returned copy is permuted to channels-first;
                # ``x`` stays channels-last for the next child.
                out[self.return_layers[name]] = x.permute(0, 3, 1, 2)
        return out
class BackboneWithFPN(nn.Module):
    """Attach a Feature Pyramid Network on top of intermediate backbone outputs.

    Adapted from ``torchvision.models.detection.backbone_utils.BackboneWithFPN``.
    It uses the ``IntermediateLayerGetter`` from this file, which permutes the
    collected features from channels-last to channels-first before the FPN.

    Args:
        backbone: module whose direct children are run in order.
        return_layers: maps child names of ``backbone`` to output keys.
        in_channels_list: channel count of each returned feature map, in the
            same order as ``return_layers``.
        out_channels: channel count of every FPN output level.
        extra_blocks: extra FPN levels; defaults to ``LastLevelMaxPool()``.
        norm_layer: optional normalisation layer used inside the FPN.

    Raises:
        ValueError: if ``in_channels_list`` and ``return_layers`` differ in length.

    Attributes:
        out_channels: FasterRCNN requires its backbone to expose this.
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        # Must run before any sub-module is assigned to ``self``.
        super().__init__()
        if len(in_channels_list) != len(return_layers):
            raise ValueError(
                f"in_channels_list has {len(in_channels_list)} entries but "
                f"return_layers selects {len(return_layers)} layers"
            )
        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
            norm_layer=norm_layer,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Return the FPN feature maps computed from image batch ``x``."""
        features = self.body(x)
        return self.fpn(features)
class CustomSwin(nn.Module):
    """Swin V2 ``features`` trunk plus FPN, usable as a FasterRCNN backbone.

    Only ``backbone_model.features`` is used. In torchvision's Swin models it
    is ``[patch embed, stage1, merge, stage2, merge, stage3, merge, stage4]``,
    so indices 1/3/5/7 are the four stage outputs. Each is taken just before
    the following patch-merging (downsampling) layer.

    Args:
        backbone_model: Swin model exposing ``.features``.
        in_channels_list: channel widths of the four returned stages. Defaults
            to the SwinV2-S widths ``[96, 192, 384, 768]``; other Swin
            variants may use different widths.
        out_channels: channels of every FPN level.
    """

    def __init__(
        self,
        backbone_model: nn.Module,
        in_channels_list: Optional[List[int]] = None,
        out_channels: int = 256,
    ) -> None:
        super().__init__()
        # Swin stage index -> FPN level name.
        return_layers = {"1": "0", "3": "1", "5": "2", "7": "3"}
        if in_channels_list is None:
            in_channels_list = [96, 192, 384, 768]
        # Re-wrap the feature layers under string names so they can be
        # selected by IntermediateLayerGetter.
        backbone_module = nn.Sequential(
            OrderedDict((str(i), layer) for i, layer in enumerate(backbone_model.features))
        )
        # NOTE(review): the stage outputs are not normalised here. The model's
        # final ``norm`` is not part of ``features``. If training stalls,
        # consider a LayerNorm per returned stage — unverified suggestion.
        # NOTE(review): the original BackboneWithFPN(...) arguments were not
        # visible; confirm whether extra_blocks / norm_layer were also passed.
        self.backbone = BackboneWithFPN(
            backbone_module,
            return_layers=return_layers,
            in_channels_list=in_channels_list,
            out_channels=out_channels,
        )
        # FasterRCNN reads ``backbone.out_channels``.
        self.out_channels = out_channels

    def forward(self, x):
        """Return the FPN feature dict for image batch ``x``."""
        return self.backbone(x)
def load_backbone(trainable_layers=6):
    """Build a pretrained ``swin_v2_s`` whose trunk is partially frozen.

    The classification head attributes (``norm``, ``permute``, ``avgpool``,
    ``flatten``, ``head``) are replaced by ``nn.Identity``.

    Args:
        trainable_layers: how many trailing entries of ``backbone.features``
            keep ``requires_grad=True``; ``0`` freezes the whole model.

    Returns:
        The modified Swin model.

    Raises:
        ValueError: if ``trainable_layers`` is outside ``[0, len(backbone.features)]``.
    """
    # Vanilla torchvision swin_v2_s with default pretrained weights.
    backbone = swin_v2_s(weights=Swin_V2_S_Weights.DEFAULT)
    for attr in ("norm", "permute", "avgpool", "flatten", "head"):
        setattr(backbone, attr, nn.Identity())

    features = list(backbone.features)
    if not 0 <= trainable_layers <= len(features):
        raise ValueError(
            f"trainable_layers should be in the range [0, {len(features)}], "
            f"got {trainable_layers}"
        )

    backbone.requires_grad_(False)
    # Slice from an explicit start index: ``features[-0:]`` would select every
    # layer, silently unfreezing the whole trunk when trainable_layers == 0.
    for layer in features[len(features) - trainable_layers:]:
        layer.requires_grad_(True)
    return backbone
# Load the backbone.
backbone = load_backbone()

# One anchor size per FPN level: '0'-'3' plus the extra 'pool' level,
# with the same aspect ratios on every level.
anchor_generator = AnchorGenerator(
    sizes=((32,), (64,), (128,), (256,), (512,)),
    aspect_ratios=((0.5, 1.0, 2.0),) * 5,
)

# RoI pooling over the four backbone levels only (the 'pool' level is skipped).
# sampling_ratio is a required argument of MultiScaleRoIAlign.
roi_pooler = MultiScaleRoIAlign(
    featmap_names=["0", "1", "2", "3"],
    output_size=(7, 7),
    sampling_ratio=2,
)

# NOTE(review): the original FasterRCNN(...) argument list was not visible;
# this wires up the pieces defined above — confirm it matches your call.
model = FasterRCNN(
    CustomSwin(backbone),
    num_classes=len(CLASSES),
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler,
)

# Replace the box predictor.
# NOTE(review): torchvision's num_classes includes the background class, so
# len(CLASSES) must count a background entry — confirm CLASSES does.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(CLASSES))
# Move the whole detector at once so every sub-module is on the same device.
model = model.to(DEVICE)
Although I’m confident I implemented everything correctly, the training loss consistently plateaus at around 1.00. In addition, during validation the model never makes a prediction with a confidence score above 0.5, even after 50 epochs. This suggests there is an issue in my implementation, but I can’t pinpoint the problem.
If anyone could review my code and help identify what might be going wrong, I would greatly appreciate it!