I am using a Faster R-CNN model to learn an object detection task with one class of objects (plus the background).
- The size of images is 1000x1000x3.
- After training, the predicted bounding boxes are always in the range [0, 300].
- Even after resizing the images to any other size, the predicted bounding boxes stay in this range.
I need help with the following questions:
- Why is this happening?
- Is there a way to avoid this? Nobody else seems to be having this problem, so is there anything that I am supposed to do while constructing the model, or at any other step?
Code for loading and training the model:
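import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
# train_one_epoch is assumed to come from engine.py in torchvision's detection reference scripts
from engine import train_one_epoch

# Start from the COCO-pretrained detector and replace the box predictor head
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 2  # 1 class (insulator) + background
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
# Keep the 1000x1000 input size instead of the default resizing
model.transform = torchvision.models.detection.transform.GeneralizedRCNNTransform(
    min_size=1000,
    max_size=1000,
    image_mean=[0.485, 0.456, 0.406],
    image_std=[0.229, 0.224, 0.225],
)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=1e-2, momentum=0.9, weight_decay=0.005)
num_epochs = 10
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader_train, torch.device('cpu'), epoch, print_freq=10)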
As mentioned, I have tried different size combinations, such as 500x500, 800x800, 500x800, and 800x500. The result is always the same: the predicted bounding boxes are in the range [0,300].
A sample set of predicted bounding boxes:
tensor([[108, 8, 140, 128],
[ 39, 213, 61, 62],
[ 36, 172, 56, 11],
[ 97, 113, 115, 193],
[247, 148, 8, 234],
[229, 220, 254, 74],
[ 28, 86, 44, 141],
[213, 122, 233, 177],
[138, 100, 209, 128],
[239, 219, 1, 61],
[ 44, 172, 58, 253],
[ 29, 176, 63, 56],
[227, 220, 244, 62],
[ 4, 240, 26, 12],
[ 37, 177, 56, 213],
[ 40, 187, 60, 225],
[ 90, 114, 107, 188],
[ 75, 112, 118, 186]], dtype=torch.uint8)
The ground truth for this sample:
tensor([[367., 782., 394., 899.],
[552., 716., 573., 829.],
[488., 466., 508., 583.],
[807., 427., 823., 524.],
[760., 661., 777., 750.],
[866., 624., 879., 706.],
[797., 858., 812., 910.]])
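One detail worth checking: the sample predictions above are printed with dtype=torch.uint8, while the ground truth is a float tensor. A uint8 tensor can only hold values from 0 to 255, so if the boxes were cast to uint8 somewhere before printing (an assumption, since that step is not shown), larger coordinates would wrap around modulo 256. The minimal sketch below illustrates that effect using the ground-truth x1 values from this sample as integers; the variable names are only illustrative.

```python
import torch

# Ground-truth x1 values from the sample above, as integer pixel coordinates.
coords = torch.tensor([367, 552, 488, 807], dtype=torch.int64)

# Casting an integer tensor to uint8 keeps only the low 8 bits (value % 256).
wrapped = coords.to(torch.uint8)
print(wrapped)        # every value is now below 256
print(coords % 256)   # the same numbers, computed explicitly

# A wider dtype preserves the original coordinates.
safe = coords.to(torch.int32)
print(safe)
```

If this is what happened, printing the raw `boxes` entry returned by the model (before any cast) should show whether the predictions actually span the full image.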